mirror of https://github.com/Jittor/Jittor
depthwise_conv
This commit is contained in:
parent 14324f364b
commit 895f6ecee2
@@ -98,6 +98,7 @@ def install_cub(root_folder):
     return dirname

 def setup_cub():
+    global cub_home
     from pathlib import Path
     cub_path = os.path.join(str(Path.home()), ".cache", "jittor", "cub")
     cub_home = install_cub(cub_path)
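The added `global cub_home` is what lets `setup_cub()` publish the downloaded CUB location at module scope; the new `grad()` further down builds its CUDA header from it. A minimal sketch of the consumer side, assuming CUDA is enabled and `setup_cub()` has already run (the two key lines are taken from this diff, the surrounding scaffolding is illustrative):

import os
from pathlib import Path
import jittor as jt

# Where install_cub() unpacks CUB (the same path setup_cub() passes in above).
cub_path = os.path.join(str(Path.home()), ".cache", "jittor", "cub")

# After setup_cub() runs, the module-level cub_home is visible to other code,
# and the depthwise backward pass prepends it to its kernel header:
cuda_header = f"#include <{jt.compile_extern.cub_home}/cub/cub.cuh>"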
@@ -1,23 +1,27 @@
 import jittor as jt
 from jittor import init
 from jittor import nn
+from jittor import Function

-def depthwise_conv(data, weights, padding, dilation, stride, fuse_relu_before_conv=0):
-    N,C,H,W = data.shape
-    o,i,Kh,Kw = weights.shape
-    assert(i == 1)
+class DepthwiseConv(Function):
+    def __init__(self, stride=1, padding=0, dilation=1):
+        self.stride = stride if isinstance(stride, tuple) else (stride, stride)
+        self.padding = padding if isinstance(padding, tuple) else (padding, padding)
+        self.dilation = dilation if isinstance(dilation, tuple) else (dilation, dilation)
+
+    def execute(self, x, weight):
+        self.save_vars = x, weight
+        N,C,H,W = x.shape
+        o,i,Kh,Kw = weight.shape
         assert(o == C)
-    if not isinstance(stride, tuple):
-        stride = (stride, stride)
-
-    oh = (H+padding[0]*2-Kh*dilation[0]+dilation[0]-1)//stride[0]+1
-    ow = (W+padding[1]*2-Kw*dilation[1]+dilation[1]-1)//stride[1]+1
+        oh = (H+self.padding[0]*2-Kh*self.dilation[0]+self.dilation[0]-1)//self.stride[0]+1
+        ow = (W+self.padding[1]*2-Kw*self.dilation[1]+self.dilation[1]-1)//self.stride[1]+1
         filter_height, filter_width = Kh, Kw
+        self.Khw = Kh, Kw
         output = jt.code(
             [N, C, oh, ow],
-            data.dtype,
-            [data, weights],
+            x.dtype,
+            [x, weight],
             cuda_header = """
 template <typename T,
           int filter_height,
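The `oh`/`ow` computed in `execute` follow the standard dilated-convolution output-size formula. A small worked example of that integer arithmetic (hypothetical numbers, not from this commit):

# Hypothetical example of the oh/ow arithmetic used in execute():
H, Kh = 28, 3                      # input height, kernel height
padding, dilation, stride = 1, 1, 1

oh = (H + padding*2 - Kh*dilation + dilation - 1) // stride + 1
# (28 + 2 - 3 + 1 - 1) // 1 + 1 = 28 -> "same" height for 3x3, pad 1, stride 1
print(oh)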
@@ -76,6 +80,61 @@ def depthwise_conv(data, weights, padding, dilation, stride, fuse_relu_before_co
     }
   }
 }
+""",
+            cuda_src=f"""
+                @alias(input, in0)
+                @alias(filter, in1)
+                @alias(output, out)
+
+                const int batch_size = input_shape0;
+                const int input_channels = input_shape1;
+                const int input_height = input_shape2;
+                const int input_width = input_shape3;
+                const int output_channels = output_shape1;
+                const int output_height = output_shape2;
+                const int output_width = output_shape3;
+                const int ksize_height = {Kh};
+                const int ksize_width = {Kw};
+                const int stride_height = {self.stride[0]};
+                const int stride_width = {self.stride[1]};
+                const int padding_height = {self.padding[0]};
+                const int padding_width = {self.padding[1]};
+                const int dilate_height = {self.dilation[0]};
+                const int dilate_width = {self.dilation[1]};
+
+                int thread = 512;
+                if (output_width > 1024 && output_width <= 2048)
+                    thread = (output_width - 1) / 2 + 1;
+                else if (output_width > 512 && output_width <= 1024)
+                    thread = output_width;
+                int blocks = std::min(std::max(thread / output_width, 1), output_height);
+                dim3 threads(std::min(output_width, thread), blocks, 1);
+                dim3 grid(output_channels, batch_size, 1);
+                KernelDepthwiseConv<
+                    input_type, ksize_height, ksize_width,
+                    stride_height, stride_width>
+                    <<<grid, threads>>>(
+                    input_p, filter_p, batch_size, output_channels, output_height,
+                    output_width, input_channels, input_height, input_width,
+                    padding_height, padding_width, dilate_height,
+                    dilate_width, output_p);
+            """
+        )
+        return output
+
+    def grad(self, grad):
+        x, weight = self.save_vars
+        Kh, Kw = self.Khw
+        return jt.code([x.shape, weight.shape], [x.dtype, weight.dtype], [x, weight, grad],
+            cuda_header = f"#include <{jt.compile_extern.cub_home}/cub/cub.cuh>"+"""
+template <typename T>
+__device__ __inline__ void CudaAtomicAddWithWarp(T* sum, T value) {
+  typedef cub::WarpReduce<T> WarpReduce;
+  typename WarpReduce::TempStorage temp_storage;
+  value = WarpReduce(temp_storage).Sum(value);
+  if (cub::LaneId() == 0)
+    atomicAdd(sum, value);
+}
+
 // CUDA kernel to compute the depthwise convolution backprop w.r.t input.
 template <typename T,
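The `execute`/`grad` pair added above follows Jittor's `Function` extension interface: `execute` computes the forward output and stashes whatever the backward pass needs in `self.save_vars`, and `grad` returns one gradient per input of `execute`. A minimal pure-Python sketch of that contract (illustrative only, not part of this commit):

import jittor as jt
from jittor import Function

class MyMul(Function):
    # Toy example of the execute/grad contract used by DepthwiseConv above.
    def execute(self, x, y):
        self.save_vars = x, y
        return x * y

    def grad(self, grad_out):
        x, y = self.save_vars
        # return one gradient per input of execute()
        return grad_out * y, grad_out * x

a = jt.random([4])
b = jt.random([4])
c = MyMul()(a, b)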
@@ -153,11 +212,12 @@ def depthwise_conv(data, weights, padding, dilation, stride, fuse_relu_before_co
     const int dilate_width, T* filter_grad_data) {
   T s = 0;

-  int gbid = ((blockIdx.z * gridDim.y) + blockIdx.y) * gridDim.x + blockIdx.x;
+  int gbid = (((blockIdx.z * blockDim.z + threadIdx.z) * gridDim.y) + blockIdx.y) * gridDim.x + blockIdx.x;

   for (int image_w = threadIdx.x; image_w < output_width;
        image_w += blockDim.x) {
     for (int bid = 0; bid < num; bid++) {
+    //for (int bid = threadIdx.z; bid < num; bid+=blockDim.z) {
       for (int image_h = threadIdx.y; image_h < output_height;
            image_h += blockDim.y) {
         int kernel_id = blockIdx.z;
@@ -183,56 +243,16 @@ def depthwise_conv(data, weights, padding, dilation, stride, fuse_relu_before_co
         }
       }
     }
-    atomicAdd(&filter_grad_data[gbid], s);
+    CudaAtomicAddWithWarp(&filter_grad_data[gbid], s);
   }
 """,

             cuda_src=f"""
-                @alias(input, in0)
-                @alias(filter, in1)
-                @alias(output, out0)
-
-                const int batch_size = input_shape0;
-                const int input_channels = input_shape1;
-                const int input_height = input_shape2;
-                const int input_width = input_shape3;
-                const int output_channels = output_shape1;
-                const int output_height = output_shape2;
-                const int output_width = output_shape3;
-                const int ksize_height = {Kh};
-                const int ksize_width = {Kw};
-                const int stride_height = {stride[0]};
-                const int stride_width = {stride[1]};
-                const int padding_height = {padding[0]};
-                const int padding_width = {padding[1]};
-                const int dilate_height = {dilation[0]};
-                const int dilate_width = {dilation[1]};
-
-                int thread = 512;
-                if (output_width > 1024 && output_width <= 2048)
-                    thread = (output_width - 1) / 2 + 1;
-                else if (output_width > 512 && output_width <= 1024)
-                    thread = output_width;
-                int blocks = std::min(std::max(thread / output_width, 1), output_height);
-                dim3 threads(std::min(output_width, thread), blocks, 1);
-                dim3 grid(output_channels, batch_size, 1);
-                KernelDepthwiseConv<
-                    input_type, ksize_height, ksize_width,
-                    stride_height, stride_width>
-                    <<<grid, threads>>>(
-                    input_p, filter_p, batch_size, output_channels, output_height,
-                    output_width, input_channels, input_height, input_width,
-                    padding_height, padding_width, dilate_height,
-                    dilate_width, output_p);
-                """,
-
-            cuda_grad_src=[
-                f"""
                 // source for backward to data
                 @alias(input, in0)
                 @alias(filter, in1)
-                @alias(input_grad, out)
                 @alias(output_grad, in2)
+                @alias(input_grad, out0)
+                @alias(filter_grad, out1)

                 const int batch_size = input_shape0;
                 const int input_channels = input_shape1;
@@ -243,12 +263,12 @@ def depthwise_conv(data, weights, padding, dilation, stride, fuse_relu_before_co
                 const int output_width = output_grad_shape3;
                 const int ksize_height = {Kh};
                 const int ksize_width = {Kw};
-                const int stride_height = {stride[0]};
-                const int stride_width = {stride[1]};
-                const int padding_height = {padding[0]};
-                const int padding_width = {padding[1]};
-                const int dilate_height = {dilation[0]};
-                const int dilate_width = {dilation[1]};
+                const int stride_height = {self.stride[0]};
+                const int stride_width = {self.stride[1]};
+                const int padding_height = {self.padding[0]};
+                const int padding_width = {self.padding[1]};
+                const int dilate_height = {self.dilation[0]};
+                const int dilate_width = {self.dilation[1]};

                 int thread = 512;
                 if (input_width > 1024 && input_width <= 2048)
@@ -266,28 +286,8 @@ def depthwise_conv(data, weights, padding, dilation, stride, fuse_relu_before_co
                     output_channels, output_height, output_width, input_channels,
                     input_height, input_width, padding_height,
                     padding_width, dilate_height, dilate_width, input_grad_p);
-                """,
-                f"""
-                // source for backward to filter
-                @alias(input, in0)
-                @alias(filter_grad, out)
-                @alias(output_grad, in2)
-
-                const int batch_size = input_shape0;
-                const int input_channels = input_shape1;
-                const int input_height = input_shape2;
-                const int input_width = input_shape3;
-                const int output_channels = output_grad_shape1;
-                const int output_height = output_grad_shape2;
-                const int output_width = output_grad_shape3;
-                const int ksize_height = {Kh};
-                const int ksize_width = {Kw};
-                const int stride_height = {stride[0]};
-                const int stride_width = {stride[1]};
-                const int padding_height = {padding[0]};
-                const int padding_width = {padding[1]};
-                const int dilate_height = {dilation[0]};
-                const int dilate_width = {dilation[1]};
+                // source for backward to filter

                 int block_size = 512;
                 if (output_width > 1024 && output_width <= 2048)
@@ -296,8 +296,10 @@ def depthwise_conv(data, weights, padding, dilation, stride, fuse_relu_before_co
                     block_size = output_width;
                 int crop_output_height =
                     std::min(std::max(block_size / output_width, 1), output_height);
-                dim3 grid(ksize_width, ksize_height, output_channels);
-                dim3 threads(std::min(output_width, block_size), crop_output_height, 1);
+                grid = dim3(ksize_width, ksize_height, output_channels);
+                threads = dim3(std::min(output_width, block_size), crop_output_height, 1);


                 KernelDepthwiseConvFilterGrad<
                     input_type><<<grid, threads, 0>>>(
@@ -307,6 +309,4 @@ def depthwise_conv(data, weights, padding, dilation, stride, fuse_relu_before_co
                     stride_height, stride_width, padding_height, padding_width,
                     dilate_height, dilate_width, filter_grad_p);
                 """
-            ]
             )
-    return output
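With the class in place, `nn.Conv` constructs it once in `__init__` and calls it in `execute` (see the hunks below). A hedged usage sketch of the `jittor.depthwise_conv.DepthwiseConv` callable on its own (shapes are hypothetical; the kernel is CUDA-only, which is why the `nn.Conv` changes also check `jt.flags.use_cuda`):

import jittor as jt
from jittor.depthwise_conv import DepthwiseConv

jt.flags.use_cuda = 1                   # the jt.code kernel above is CUDA-only

x = jt.random([8, 32, 28, 28])          # N, C, H, W (hypothetical shapes)
w = jt.random([32, 1, 3, 3])            # one Kh x Kw filter per channel

conv = DepthwiseConv(stride=1, padding=1, dilation=1)
y = conv(x, w)                          # -> [8, 32, 28, 28] for pad=1, stride=1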
@@ -153,7 +153,7 @@ def get_init_var_rand(shape, dtype):

 def relu(x): return jt.ternary((x>0.0), x, jt.broadcast_var(0.0, x))
 def leaky_relu(x, scale=0.01): return jt.ternary(x>0, x, x*scale)
-def relu6(x): return jt.minimum(jt.maximum(x, 0), 6)
+def relu6(x): return jt.minimum(jt.maximum(x, 0.0), 6.0)
 def sign(x):
     one = jt.ones(x.shape)
     x = jt.ternary(x>0, one, x)
@@ -473,7 +473,7 @@ ReLU6 = jt.make_module(relu6)
 Softmax = jt.make_module(softmax, 2)
 GELU = jt.make_module(gelu)

-from jittor.depthwise_conv import depthwise_conv
+from jittor.depthwise_conv import DepthwiseConv

 class Conv(Module):
     def __init__(self, in_channels, out_channels, kernel_size, stride=1, padding=0, dilation=1, groups=1, bias=True):
@@ -485,7 +485,8 @@ class Conv(Module):
         self.dilation = dilation if isinstance(dilation, tuple) else (dilation, dilation)
         self.groups = groups
         self.is_depthwise_conv = self.groups == self.out_channels and self.groups == self.in_channels
-        # self.is_depthwise_conv = False
+        if self.is_depthwise_conv and jt.flags.use_cuda:
+            self.depthwise_conv = DepthwiseConv(stride, padding, dilation)
         assert in_channels % groups == 0, 'in_channels must be divisible by groups'
         assert out_channels % groups == 0, 'out_channels must be divisible by groups'
         Kh, Kw = self.kernel_size
@@ -505,8 +506,12 @@ class Conv(Module):
             self.bias = None

     def execute(self, x):
-        if self.is_depthwise_conv:
-            return depthwise_conv(x, self.weight, self.padding, self.dilation, self.stride)
+        if self.is_depthwise_conv and jt.flags.use_cuda:
+            y = self.depthwise_conv(x, self.weight)
+            if self.bias is not None:
+                b = self.bias.broadcast(y.shape, [0,2,3])
+                y = y + b
+            return y
         elif self.groups == 1:
             N,C,H,W = x.shape
             Kh, Kw = self.kernel_size
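In the new depthwise branch of `Conv.execute`, the bias is added manually by broadcasting the per-channel vector over the batch and spatial axes. A short sketch of that broadcast (hypothetical shapes):

import jittor as jt

y = jt.random([8, 32, 28, 28])      # conv output: N, C, H, W (hypothetical)
bias = jt.random([32])              # one bias per channel

# broadcast(shape, dims): dims lists the axes the 1-D bias is repeated over,
# so the remaining axis (1, the channel axis) lines up with bias.
b = bias.broadcast(y.shape, [0, 2, 3])
y = y + b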
@@ -541,7 +546,6 @@ class Conv(Module):
                 f'i4*{self.stride[0]}-{self.padding[0]}+i6*{self.dilation[0]}', # Hid+Khid
                 f'i5*{self.stride[1]}-{self.padding[1]}+i7*{self.dilation[1]}', # Wid+KWid
             ])
-            xx.compile_options = {"G":G}
             # w: [oc, CpG, Kh, Kw]
             ww = self.weight.reindex([N, G, oc//G, CpG, oh, ow, Kh, Kw], [
                 f'i1*{oc//G}+i2',
@@ -549,6 +553,7 @@ class Conv(Module):
                 'i6',
                 'i7'
             ])
+            ww.compile_options = xx.compile_options = {"G":G,"C":C}
             yy = xx*ww
             y = yy.reindex_reduce('add', [N, oc, oh, ow], [
                 'i0',
@@ -0,0 +1,82 @@
+# ***************************************************************
+# Copyright (c) 2020 Jittor. Authors: Dun Liang <randonlang@gmail.com>. All Rights Reserved.
+# This file is subject to the terms and conditions defined in
+# file 'LICENSE.txt', which is part of this source code package.
+# ***************************************************************
+import unittest
+import jittor as jt
+import numpy as np
+import jittor.models as jtmodels
+
+def load_parameters(m1, m2):
+    m1.save('temp.pk')
+    m2.load('temp.pk')
+
+def compare_parameters(m1, m2):
+    ps1 = m1.parameters()
+    ps2 = m2.parameters()
+    for i in range(len(ps1)):
+        x = ps1[i].data + 1e-8
+        y = ps2[i].data + 1e-8
+        relative_error = abs(x - y) / abs(y)
+        diff = relative_error.mean()
+        assert diff < 1e-4, (diff, 'backward', ps2[i].name())
+
+class TestDepthwiseConv(unittest.TestCase):
+    @unittest.skipIf(not jt.has_cuda, "Cuda not found")
+    @jt.flag_scope(use_cuda=1)
+    def test_data(self):
+        test_img = np.random.random((64,3,224,224)).astype('float32')
+        jittor_test_img = jt.array(test_img)
+        lr = 100
+
+        jittor_model = jtmodels.__dict__['mobilenet_v2']()
+        jittor_model2 = jtmodels.__dict__['mobilenet_v2']()
+        # Set eval to avoid dropout layer & bn errors
+        jittor_model.train()
+        jittor_model.classifier[0].eval()
+        for m in jittor_model.modules():
+            if isinstance(m, jt.nn.BatchNorm):
+                m.eval()
+
+        jittor_model2.train()
+        jittor_model2.classifier[0].eval()
+        for m in jittor_model2.modules():
+            if isinstance(m, jt.nn.BatchNorm):
+                m.eval()
+
+        load_parameters(jittor_model2, jittor_model)
+        for m in jittor_model.modules():
+            if isinstance(m, jt.nn.Conv):
+                m.is_depthwise_conv = False
+        cnt = 0
+        for m in jittor_model2.modules():
+            if isinstance(m, jt.nn.Conv):
+                if (m.is_depthwise_conv):
+                    cnt += 1
+        assert cnt == 17, (cnt, '!=', 17)
+        jt_optimizer = jt.nn.SGD(jittor_model.parameters(), lr = lr)
+        jt_optimizer2 = jt.nn.SGD(jittor_model2.parameters(), lr = lr)
+
+        jittor_result = jittor_model(jittor_test_img)
+        loss = jittor_result.sum()
+        jt_optimizer.step(loss)
+        jt.sync_all(True)
+
+        jittor_result2 = jittor_model2(jittor_test_img)
+        loss = jittor_result2.sum()
+        jt_optimizer2.step(loss)
+        jt.sync_all(True)
+        compare_parameters(jittor_model, jittor_model2)
+
+        x = jittor_result2.data + 1e-8
+        y = jittor_result.data + 1e-8
+        relative_error = abs(x - y) / abs(y)
+        diff = relative_error.mean()
+        assert diff < 1e-4, (diff, 'forword')
+
+        jt.clean()
+        jt.gc()
+
+if __name__ == "__main__":
+    unittest.main()