mirror of https://github.com/Jittor/Jittor
Merge branch 'master' of https://github.com/Jittor/jittor into test_init
commit b5fbf29385
@@ -14,6 +14,9 @@
namespace jittor {

extern cudnnHandle_t cudnn_handle;
constexpr int max_cache_size=100;
extern int max_cache_size;

// @pyjt(set_algorithm_cache_size)
void set_algorithm_cache_size(int size);

} // jittor
@@ -8,6 +8,11 @@
namespace jittor {

cudnnHandle_t cudnn_handle;
int max_cache_size = 100;

void set_algorithm_cache_size(int size) {
    max_cache_size = size;
}

struct cudnn_initer {
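The two hunks above turn the fixed cuDNN algorithm-cache limit into a runtime setting and expose it to Python through the @pyjt(set_algorithm_cache_size) binding. A minimal, hedged usage sketch (the benchmark script at the end of this diff calls the same function; availability of jt.cudnn depends on the CUDA setup):

    import jittor as jt

    # Enlarge the cuDNN algorithm cache (the C++ default above is 100).
    # jt.cudnn is only populated when CUDA/cuDNN were detected at setup time.
    if jt.has_cuda and jt.cudnn is not None:
        jt.cudnn.set_algorithm_cache_size(10000)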
@@ -16,7 +16,8 @@ with lock.lock_scope():
    from jittor_core import *
    from jittor_core.ops import *
    from . import compile_extern
    from .compile_extern import mkl_ops, mpi, mpi_ops
    from .compile_extern import mkl_ops, mpi, mpi_ops, \
        cudnn, curand, cublas

import contextlib
import numpy as np
@@ -699,7 +700,7 @@ def jittor_exit():
atexit.register(jittor_exit)

Var.__str__ = lambda x: str(x.data)
Var.__repr__ = lambda x: f"jt.Var:{x.dtype}{x.uncertain_shape}"
Var.__repr__ = lambda x: str(x.data)
Var.peek = lambda x: f"{x.dtype}{x.shape}"

from . import nn
@@ -132,6 +132,7 @@ def setup_cuda_extern():

def setup_cuda_lib(lib_name, link=True, extra_flags=""):
    globals()[lib_name+"_ops"] = None
    globals()[lib_name] = None
    if not has_cuda: return
    LOG.v(f"setup {lib_name}...")
@@ -157,9 +158,11 @@ def setup_cuda_lib(lib_name, link=True, extra_flags=""):
        return

    # compile and get operators
    culib_ops = compile_custom_ops(culib_src_files,
    culib = compile_custom_ops(culib_src_files, return_module=True,
        extra_flags=f" -I'{jt_cuda_include}' -I'{jt_culib_include}' {link_flags} {extra_flags} ")
    culib_ops = culib.ops
    globals()[lib_name+"_ops"] = culib_ops
    globals()[lib_name] = culib
    LOG.vv(f"Get {lib_name}_ops: "+str(dir(culib_ops)))

def install_cutt(root_folder):
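With return_module=True, compile_custom_ops now hands back the whole compiled module, so setup_cuda_lib can publish both the operator namespace and the module object. A hedged sketch of what this makes reachable after a successful cuDNN setup (names follow the globals()[...] assignments above; both are None without CUDA):

    import jittor as jt
    from jittor import compile_extern

    if compile_extern.cudnn is not None:
        print(dir(compile_extern.cudnn_ops))  # raw cudnn operators
        print(compile_extern.cudnn)           # compiled module, including @pyjt functions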
@@ -58,6 +58,7 @@ def argmax_pool(x, size, stride, padding=0):
def concat(arr, dim):
    # TODO: low performance when concat lots of vars
    total_dim = 0
    if dim < 0: dim += len(arr[0].shape)
    for a in arr:
        total_dim += a.shape[dim]
    cdim = 0
@@ -75,9 +75,21 @@ def linear(x, n):
    return jt.matmul(x, w) + b

def relu(x): return jt.maximum(x, 0)
def leaky_relu(x, scale): return jt.ternary(x>0, x, x*scale)
def leaky_relu(x, scale=0.01): return jt.ternary(x>0, x, x*scale)
def relu6(x): return jt.minimum(jt.maximum(x, 0), 6)

class PReLU(Module):
    def __init__(self, num_parameters=1, init_=0.25):
        self.num_parameters = num_parameters
        self.a = init.constant((num_parameters,), "float32", init_)

    def execute(self, x):
        if self.num_parameters != 1:
            assert self.num_parameters == x.size(1), f"num_parameters does not match input channels in PReLU"
            return jt.maximum(0, x) + self.a.broadcast(x, [0,2,3]) * jt.minimum(0, x)
        else:
            return jt.maximum(0, x) + self.a * jt.minimum(0, x)

#TODO dims is 4 will cause slowly execution
def cross_entropy_loss(output, target, ignore_index=None):
    if len(output.shape) == 4:
@@ -317,7 +329,7 @@ class BatchNorm(Module):

Relu = jt.make_module(relu)
ReLU = Relu
Leaky_relu = jt.make_module(leaky_relu, 0.01)
Leaky_relu = jt.make_module(leaky_relu, 2)
LeakyReLU = Leaky_relu
ReLU6 = jt.make_module(relu6)
Softmax = jt.make_module(softmax, 2)
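leaky_relu now defaults to scale=0.01 and PReLU becomes a learnable module, while jt.make_module keeps exposing module wrappers around the functional forms. A minimal sketch of both spellings, mirroring the activation tests added later in this commit:

    import numpy as np
    import jittor as jt
    from jittor import nn

    x = jt.array(np.random.randn(2, 3).astype("float32"))

    y1 = nn.leaky_relu(x)   # functional form, new default scale=0.01
    y2 = nn.LeakyReLU()(x)  # module wrapper produced by jt.make_module
    y3 = nn.PReLU()(x)      # single learnable slope, init_=0.25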
@@ -401,7 +413,7 @@ class Conv(Module):
        if self.bias is not None:
            b = self.bias.broadcast(y.shape, [0,2,3])
            y = y + b
            return y
        return y


class ConvTranspose(Module):
@@ -456,6 +468,166 @@ class ConvTranspose(Module):
        return y


class ReflectionPad2d(Module):
    def __init__(self, padding):
        self.padding = padding
        if isinstance(self.padding, int):
            self.pl = self.padding
            self.pr = self.padding
            self.pt = self.padding
            self.pb = self.padding
        elif isinstance(self.padding, tuple):
            self.pl, self.pr, self.pt, self.pb = self.padding
        else:
            raise TypeError(f"ReflectionPad2d padding just support int or tuple, but found {type(padding)}")

    def execute(self, x):
        n,c,h,w = x.shape
        assert (self.pl < w and self.pr < w), f"padding_left and padding_right should be smaller than input width"
        assert (self.pt < h and self.pb < h), f"padding_top and padding_bottom should be smaller than input height"
        oh=h+self.pt+self.pb
        ow=w+self.pl+self.pr
        l = self.pl
        r = self.pl + w - 1
        t = self.pt
        b = self.pt + h - 1
        x_idx = np.zeros((oh,ow))
        y_idx = np.zeros((oh,ow))
        for j in range(oh):
            for i in range(ow):
                if i >= l and i <= r and j >= t and j <= b:
                    x_idx[j,i] = i
                    y_idx[j,i] = j
                elif i < l and j < t:
                    x_idx[j,i] = 2 * l - i
                    y_idx[j,i] = 2 * t - j
                elif i < l and j > b:
                    x_idx[j,i] = 2 * l - i
                    y_idx[j,i] = 2 * b - j
                elif i > r and j < t:
                    x_idx[j,i] = 2 * r - i
                    y_idx[j,i] = 2 * t - j
                elif i > r and j > b:
                    x_idx[j,i] = 2 * r - i
                    y_idx[j,i] = 2 * b - j
                elif i < l:
                    x_idx[j,i] = 2 * l - i
                    y_idx[j,i] = j
                elif i > r:
                    x_idx[j,i] = 2 * r - i
                    y_idx[j,i] = j
                elif j < t:
                    x_idx[j,i] = i
                    y_idx[j,i] = 2 * t - j
                elif j > b:
                    x_idx[j,i] = i
                    y_idx[j,i] = 2 * b - j
        return x.reindex([n,c,oh,ow], ["i0","i1","@e1(i2,i3)","@e0(i2,i3)"], extras=[jt.array(x_idx - self.pl), jt.array(y_idx - self.pt)])

class ZeroPad2d(Module):
    def __init__(self, padding):
        self.padding = padding
        if isinstance(self.padding, int):
            self.pl = self.padding
            self.pr = self.padding
            self.pt = self.padding
            self.pb = self.padding
        elif isinstance(self.padding, tuple):
            self.pl, self.pr, self.pt, self.pb = self.padding
        else:
            raise TypeError(f"ZeroPad2d padding just support int or tuple, but found {type(padding)}")

    def execute(self, x):
        n,c,h,w = x.shape
        return x.reindex([n,c,h+self.pt+self.pb,w+self.pl+self.pr], ["i0","i1",f"i2-{self.pt}",f"i3-{self.pl}"])

class ConstantPad2d(Module):
    def __init__(self, padding, value):
        self.padding = padding
        if isinstance(self.padding, int):
            self.pl = self.padding
            self.pr = self.padding
            self.pt = self.padding
            self.pb = self.padding
        elif isinstance(self.padding, tuple):
            self.pl, self.pr, self.pt, self.pb = self.padding
        else:
            raise TypeError(f"ConstantPad2d padding just support int or tuple, but found {type(padding)}")
        self.value = value

    def execute(self, x):
        n,c,h,w = x.shape
        return x.reindex([n,c,h+self.pt+self.pb,w+self.pl+self.pr], ["i0","i1",f"i2-{self.pt}",f"i3-{self.pl}"], overflow_value=self.value)

class ReplicationPad2d(Module):
    def __init__(self, padding):
        self.padding = padding
        if isinstance(self.padding, int):
            self.pl = self.padding
            self.pr = self.padding
            self.pt = self.padding
            self.pb = self.padding
        elif isinstance(self.padding, tuple):
            self.pl, self.pr, self.pt, self.pb = self.padding
        else:
            raise TypeError(f"ReplicationPad2d padding just support int or tuple, but found {type(padding)}")

    def execute(self, x):
        n,c,h,w = x.shape
        oh=h+self.pt+self.pb
        ow=w+self.pl+self.pr
        l = self.pl
        r = self.pl + w - 1
        t = self.pt
        b = self.pt + h - 1
        x_idx = np.zeros((oh,ow))
        y_idx = np.zeros((oh,ow))
        for j in range(oh):
            for i in range(ow):
                if i >= l and i <= r and j >= t and j <= b:
                    x_idx[j,i] = i
                    y_idx[j,i] = j
                elif i < l and j < t:
                    x_idx[j,i] = l
                    y_idx[j,i] = t
                elif i < l and j > b:
                    x_idx[j,i] = l
                    y_idx[j,i] = b
                elif i > r and j < t:
                    x_idx[j,i] = r
                    y_idx[j,i] = t
                elif i > r and j > b:
                    x_idx[j,i] = r
                    y_idx[j,i] = b
                elif i < l:
                    x_idx[j,i] = l
                    y_idx[j,i] = j
                elif i > r:
                    x_idx[j,i] = r
                    y_idx[j,i] = j
                elif j < t:
                    x_idx[j,i] = i
                    y_idx[j,i] = t
                elif j > b:
                    x_idx[j,i] = i
                    y_idx[j,i] = b
        return x.reindex([n,c,oh,ow], ["i0","i1","@e1(i2,i3)","@e0(i2,i3)"], extras=[jt.array(x_idx - self.pl), jt.array(y_idx - self.pt)])

class PixelShuffle(Module):
    def __init__(self, upscale_factor):
        self.upscale_factor = upscale_factor

    def execute(self, x):
        n,c,h,w = x.shape
        r = self.upscale_factor
        assert c%(r**2)==0, f"input channel needs to be divided by upscale_factor's square in PixelShuffle"
        return x.reindex([n,int(c/r**2),h*r,w*r], [
            "i0",
            f"i1*{r**2}+i2%{r}*{r}+i3%{r}",
            f"i2/{r}",
            f"i3/{r}"
        ])

class Tanh(Module):
    def __init__(self):
        super().__init__()
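The four padding modules above build an explicit (x_idx, y_idx) coordinate map on the host and hand it to Var.reindex, and PixelShuffle is a pure reindex as well. A short shape-only sketch, assuming the modules are used exactly as in the tests added later in this commit:

    import numpy as np
    import jittor as jt
    from jittor import nn

    x = jt.array(np.random.randn(1, 4, 8, 8).astype("float32"))

    print(nn.ReflectionPad2d(2)(x).shape)              # [1,4,12,12]
    print(nn.ReplicationPad2d((1, 2, 3, 4))(x).shape)  # (left,right,top,bottom) -> [1,4,15,11]
    print(nn.ZeroPad2d(1)(x).shape)                    # [1,4,10,10]
    print(nn.ConstantPad2d(1, -2.0)(x).shape)          # [1,4,10,10]

    y = jt.array(np.random.randn(1, 16, 8, 8).astype("float32"))
    print(nn.PixelShuffle(upscale_factor=2)(y).shape)  # [1,4,16,16]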
@@ -474,8 +646,8 @@ def resize(x, size, mode="nearest"):
    H,W = size
    new_size = [n,c,H,W]
    nid, cid, hid, wid = jt.index(new_size)
    x = hid * ((h-1)/(H-1))
    y = wid * ((w-1)/(W-1))
    x = hid * h / H
    y = wid * w / W
    if mode=="nearest":
        return img.reindex([nid, cid, x.floor(), y.floor()])
    if mode=="bilinear":
@@ -493,7 +665,13 @@ def resize(x, size, mode="nearest"):
        return o
    raise(f"Not support {interpolation}")


class Upsample(Module):
    def __init__(self, scale_factor=None, mode='nearest'):
        self.scale_factor = scale_factor if isinstance(scale_factor, tuple) else (scale_factor, scale_factor)
        self.mode = mode

    def execute(self, x):
        return resize(x, size=(int(x.shape[2]*self.scale_factor[0]), int(x.shape[3]*self.scale_factor[1])), mode=self.mode)

class Sequential(Module):
    def __init__(self, *args):
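resize now maps output pixel i to input coordinate i*h/H rather than the corner-aligned (h-1)/(H-1) scaling, and Upsample is a thin module wrapper around it. A minimal sketch, matching the Upsample test added later in this commit:

    import numpy as np
    import jittor as jt
    from jittor import nn

    x = jt.array(np.random.randn(1, 3, 4, 4).astype("float32"))

    print(nn.Upsample(scale_factor=2)(x).shape)                    # [1,3,8,8]
    print(nn.Upsample(scale_factor=0.5, mode='nearest')(x).shape)  # [1,3,2,2]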
@@ -0,0 +1,66 @@
# ***************************************************************
# Copyright (c) 2020 Jittor. Authors:
# Wenyang Zhou <576825820@qq.com>
# Dun Liang <randonlang@gmail.com>.
# All Rights Reserved.
# This file is subject to the terms and conditions defined in
# file 'LICENSE.txt', which is part of this source code package.
# ***************************************************************
import unittest
import jittor as jt
import numpy as np
import jittor.nn as jnn

skip_this_test = False

try:
    jt.dirty_fix_pytorch_runtime_error()
    import torch
    import torch.nn as tnn
except:
    torch = None
    tnn = None
    skip_this_test = True

def check_equal(arr, j_layer, p_layer, is_train=True, threshold=1e-5):
    jittor_arr = jt.array(arr)
    pytorch_arr = torch.Tensor(arr)
    if is_train:
        assert np.allclose(p_layer.running_mean.detach().numpy(), j_layer.running_mean.numpy(), threshold)
        assert np.allclose(p_layer.running_var.detach().numpy(), j_layer.running_var.numpy(), threshold)
    else:
        assert np.allclose(p_layer.layer.running_mean.detach().numpy(), j_layer.running_mean.numpy(), threshold)
        assert np.allclose(p_layer.layer.running_var.detach().numpy(), j_layer.running_var.numpy(), threshold)
    jittor_result = j_layer(jittor_arr)
    pytorch_result = p_layer(pytorch_arr)
    if is_train:
        assert np.allclose(p_layer.running_mean.detach().numpy(), j_layer.running_mean.numpy(), threshold)
        assert np.allclose(p_layer.running_var.detach().numpy(), j_layer.running_var.numpy(), threshold)
    else:
        assert np.allclose(p_layer.layer.running_mean.detach().numpy(), j_layer.running_mean.numpy(), threshold)
        assert np.allclose(p_layer.layer.running_var.detach().numpy(), j_layer.running_var.numpy(), threshold)
    assert np.allclose(pytorch_result.detach().numpy(), jittor_result.numpy(), threshold)

@unittest.skipIf(skip_this_test, "No Torch found")
class TestBatchNorm(unittest.TestCase):
    def test_batchnorm(self):
        # ***************************************************************
        # Test BatchNorm Layer
        # ***************************************************************
        arr = np.random.randn(16,10,224,224)
        check_equal(arr, jnn.BatchNorm(10, is_train=True), tnn.BatchNorm2d(10))

        class Model(tnn.Module):
            def __init__(self):
                super(Model, self).__init__()
                self.layer = tnn.BatchNorm2d(10)
            def forward(self, x):
                return self.layer(x)
        model = Model()
        model.eval()
        check_equal(arr, jnn.BatchNorm(10, is_train=False), model, False)


if __name__ == "__main__":
    unittest.main()
@@ -0,0 +1,65 @@
# ***************************************************************
# Copyright (c) 2020 Jittor. Authors:
# Wenyang Zhou <576825820@qq.com>
# Dun Liang <randonlang@gmail.com>.
# All Rights Reserved.
# This file is subject to the terms and conditions defined in
# file 'LICENSE.txt', which is part of this source code package.
# ***************************************************************
import unittest
import jittor as jt
import numpy as np
import jittor.nn as jnn

skip_this_test = False

try:
    jt.dirty_fix_pytorch_runtime_error()
    import torch
    import torch.nn as tnn
except:
    torch = None
    tnn = None
    skip_this_test = True

def check_equal(arr, j_layer, p_layer):
    jittor_arr = jt.array(arr)
    pytorch_arr = torch.Tensor(arr)
    jittor_result = j_layer(jittor_arr)
    pytorch_result = p_layer(pytorch_arr)
    assert np.allclose(pytorch_result.detach().numpy(), jittor_result.numpy())

@unittest.skipIf(skip_this_test, "No Torch found")
class TestPad(unittest.TestCase):
    def test_pad(self):
        # ***************************************************************
        # Test ReplicationPad2d Layer
        # ***************************************************************
        arr = np.random.randn(16,3,224,224)
        check_equal(arr, jnn.ReplicationPad2d(10), tnn.ReplicationPad2d(10))
        check_equal(arr, jnn.ReplicationPad2d((1,23,4,5)), tnn.ReplicationPad2d((1,23,4,5)))

        # ***************************************************************
        # Test ConstantPad2d Layer
        # ***************************************************************
        arr = np.random.randn(16,3,224,224)
        check_equal(arr, jnn.ConstantPad2d(10,-2), tnn.ConstantPad2d(10,-2))
        check_equal(arr, jnn.ConstantPad2d((2,3,34,1),10.2), tnn.ConstantPad2d((2,3,34,1),10.2))

        # ***************************************************************
        # Test ZeroPad2d Layer
        # ***************************************************************
        arr = np.random.randn(16,3,224,224)
        check_equal(arr, jnn.ZeroPad2d(1), tnn.ZeroPad2d(1))
        check_equal(arr, jnn.ZeroPad2d((2,3,34,1)), tnn.ZeroPad2d((2,3,34,1)))

        # ***************************************************************
        # Test ReflectionPad2d Layer
        # ***************************************************************
        arr = np.random.randn(16,3,224,224)
        check_equal(arr, jnn.ReflectionPad2d(20), tnn.ReflectionPad2d(20))
        check_equal(arr, jnn.ReflectionPad2d((2,3,34,1)), tnn.ReflectionPad2d((2,3,34,1)))

if __name__ == "__main__":
    unittest.main()
@@ -0,0 +1,66 @@
# ***************************************************************
# Copyright (c) 2020 Jittor. Authors:
# Wenyang Zhou <576825820@qq.com>
# Dun Liang <randonlang@gmail.com>.
# All Rights Reserved.
# This file is subject to the terms and conditions defined in
# file 'LICENSE.txt', which is part of this source code package.
# ***************************************************************
import unittest
import jittor as jt
import numpy as np
import jittor.nn as jnn

skip_this_test = False

try:
    jt.dirty_fix_pytorch_runtime_error()
    import torch
    import torch.nn as tnn
except:
    torch = None
    tnn = None
    skip_this_test = True

def check_equal(arr, j_layer, p_layer):
    jittor_arr = jt.array(arr)
    pytorch_arr = torch.Tensor(arr)
    jittor_result = j_layer(jittor_arr)
    pytorch_result = p_layer(pytorch_arr)
    assert np.allclose(pytorch_result.detach().numpy(), jittor_result.numpy())

@unittest.skipIf(skip_this_test, "No Torch found")
class TestRelu(unittest.TestCase):
    def test_relu(self):
        # ***************************************************************
        # Test ReLU Layer
        # ***************************************************************
        arr = np.random.randn(16,10,224,224)
        check_equal(arr, jnn.ReLU(), tnn.ReLU())

        # ***************************************************************
        # Test PReLU Layer
        # ***************************************************************
        arr = np.random.randn(16,10,224,224)
        check_equal(arr, jnn.PReLU(), tnn.PReLU())
        check_equal(arr, jnn.PReLU(10, 99.9), tnn.PReLU(10, 99.9))
        check_equal(arr, jnn.PReLU(10, 2), tnn.PReLU(10, 2))
        check_equal(arr, jnn.PReLU(10, -0.2), tnn.PReLU(10, -0.2))

        # ***************************************************************
        # Test ReLU6 Layer
        # ***************************************************************
        arr = np.random.randn(16,10,224,224)
        check_equal(arr, jnn.ReLU6(), tnn.ReLU6())

        # ***************************************************************
        # Test LeakyReLU Layer
        # ***************************************************************
        arr = np.random.randn(16,10,224,224)
        check_equal(arr, jnn.LeakyReLU(), tnn.LeakyReLU())
        check_equal(arr, jnn.LeakyReLU(2), tnn.LeakyReLU(2))
        check_equal(arr, jnn.LeakyReLU(99.9), tnn.LeakyReLU(99.9))

if __name__ == "__main__":
    unittest.main()
@@ -11,6 +11,17 @@ import jittor as jt
import random
import os

import numpy as np
import jittor.nn as jnn
try:
    jt.dirty_fix_pytorch_runtime_error()
    import torch
    import torch.nn as tnn
except:
    torch = None
    tnn = None
    skip_this_test = True

mid = 0
if os.uname()[1] == "jittor-ce":
    mid = 1
@@ -74,6 +85,13 @@ def test_case(box_num, out_size, time_limit):
    assert fused_op_num == 1, fused_op_num
    assert t <= time_limit, t

def check_equal(arr, j_layer, p_layer):
    jittor_arr = jt.array(arr)
    pytorch_arr = torch.Tensor(arr)
    jittor_result = j_layer(jittor_arr)
    pytorch_result = p_layer(pytorch_arr)
    assert np.allclose(pytorch_result.detach().numpy(), jittor_result.numpy())

class TestResizeAndCrop(unittest.TestCase):
    def test(self):
        test_case(100, [224, 224], 0.45)
@@ -81,5 +99,16 @@ class TestResizeAndCrop(unittest.TestCase):
        test_case(20, [1024, 1024], [1.2, 1.8][mid])
        test_case(20, [1024, 666], [0.8,1.0][mid])

    def test_upsample(self):
        arr = np.random.randn(16,10,224,224)
        check_equal(arr, jnn.Upsample(scale_factor=2), tnn.Upsample(scale_factor=2))
        check_equal(arr, jnn.Upsample(scale_factor=0.2), tnn.Upsample(scale_factor=0.2))

    def test_pixelshuffle(self):
        arr = np.random.randn(16,16,224,224)
        check_equal(arr, jnn.PixelShuffle(upscale_factor=2), tnn.PixelShuffle(upscale_factor=2))
        arr = np.random.randn(1,16*16,224,224)
        check_equal(arr, jnn.PixelShuffle(upscale_factor=16), tnn.PixelShuffle(upscale_factor=16))

if __name__ == "__main__":
    unittest.main()
@@ -114,16 +114,31 @@ class Compose:
        return data

class Resize:
    def __init__(self, size):
    def __init__(self, size, mode=Image.BILINEAR):
        if isinstance(size, int):
            size = (size, size)
        assert isinstance(size, tuple)
        self.size = size
        self.mode = mode
    def __call__(self, img:Image.Image):
        return img.resize(self.size, Image.BILINEAR)
        return img.resize(self.size, self.mode)

class Gray:
    def __call__(self, img:Image.Image):
        img = np.array(img.convert('L'))
        img = img[np.newaxis, :]
        return np.array((img / 255.0), dtype = np.float32)

class RandomCrop:
    def __init__(self, size):
        if isinstance(size, int):
            size = (size, size)
        assert isinstance(size, tuple)
        self.size = size
    def __call__(self, img:Image.Image):
        width, height = img.size
        assert self.size[0] <= height and self.size[1] <= width, f"crop size exceeds the input image in RandomCrop"
        top = np.random.randint(0,height-self.size[0]+1)
        left = np.random.randint(0,width-self.size[1]+1)
        return crop(img, top, left, self.size[0], self.size[1])
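Resize now takes an interpolation mode, and Gray and RandomCrop are new transforms. A small pipeline sketch, assuming the module is imported as jittor.transform and Compose applies the listed transforms in order as the hunk header suggests:

    from PIL import Image
    import numpy as np
    from jittor import transform

    pipeline = transform.Compose([
        transform.Resize(256, mode=Image.NEAREST),  # per-instance interpolation mode
        transform.RandomCrop(224),
        transform.Gray(),                           # float32 array, shape (1, 224, 224)
    ])

    img = Image.fromarray(np.zeros((300, 400, 3), dtype=np.uint8))
    out = pipeline(img)
    print(out.shape, out.dtype)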
@@ -76,6 +76,7 @@ pjmap = {
        },
        'links': {},
        'extras': {},
        'delete': ['inplace'],
    },
    'ReLU6': {
        'pytorch': {
@@ -88,6 +89,19 @@ pjmap = {
        },
        'links': {},
        'extras': {},
        'delete': ['inplace'],
    },
    'PReLU': {
        'pytorch': {
            'args': 'num_parameters=1, init=0.25',
        },
        'jittor': {
            'module': 'nn',
            'name': 'PReLU',
            'args': 'num_parameters=1, init_=0.25'
        },
        'links': {'init': 'init_'},
        'extras': {},
    },
    'LeakyReLU': {
        'pytorch': {
@@ -96,10 +110,11 @@ pjmap = {
        'jittor': {
            'module': 'nn',
            'name': 'LeakyReLU',
            'args': 'scale'
            'args': 'scale=0.01'
        },
        'links': {'negative_slope': 'scale'},
        'extras': {},
        'delete': ['inplace'],
    },
    'BatchNorm2d': {
        'pytorch': {
@@ -113,6 +128,19 @@ pjmap = {
        'links': {},
        'extras': {},
    },
    'BatchNorm1d': {
        'pytorch': {
            'args': "num_features, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True"
        },
        'jittor': {
            'module': 'nn',
            'name': 'BatchNorm1d',
            'args': 'num_features, eps=1e-5, momentum=0.1, affine=None, is_train=True, sync=True',
        },
        'links': {},
        'extras': {'affine': 'None'},
        'delete': ['track_running_stats'],
    },
    'Dropout2d': {
        'pytorch': {
            'args': 'p=0.5, inplace=False',
@@ -124,6 +152,19 @@ pjmap = {
        },
        'links': {},
        'extras': {},
        'delete': ['inplace'],
    },
    'Upsample': {
        'pytorch': {
            'args': "size=None, scale_factor=None, mode='nearest', align_corners=None",
        },
        'jittor': {
            'module': 'nn',
            'name': 'Upsample',
            'args': "scale_factor=None, mode='nearest'"
        },
        'links': {},
        'extras': {},
    },
    'kaiming_normal_': {
        'pytorch': {
@@ -161,6 +202,18 @@ pjmap = {
        'links': {'tensor': 'var'},
        'extras': {},
    },
    'uniform_': {
        'pytorch': {
            'args': "tensor, a=0.0, b=1.0",
        },
        'jittor': {
            'module': 'init',
            'name': 'uniform_',
            'args': 'var, low, high'
        },
        'links': {'tensor': 'var', 'a': 'low', 'b': 'high'},
        'extras': {},
    },
    'cat': {
        'pytorch': {
            'args': "tensors, dim=0, out=None",
@@ -225,7 +278,6 @@ pjmap = {
        'links': {},
        'extras': {},
    },
    # probably not needed if it is exactly the same
    'view': {
        'pytorch': {
            'prefix': [],
@@ -244,15 +296,74 @@ pjmap = {
    }
}


def pjmap_append(pytorch_func_name, pytorch_args, jittor_func_module, jittor_func_name, jittor_args, extras=None, links=None, delete=None):
    ''' adding map to pjmap for converting new function, example: convert AvgPool2d to Pool

    args:
    * `pytorch_func_name`: Pytorch function name
    * `pytorch_args`: Pytorch parameter list
    * `jittor_func_module`: to which module the Jittor function belongs
    * `jittor_func_name`: Jittor function name
    * `jittor_args`: Jittor parameter list
    * `extras`: parameter assignment
    * `links`: connection parameters
    * `delete`: delete parameters

    example:
    from jittor.utils.pytorch_converter import pjmap_append
    pjmap_append(pytorch_func_name='AvgPool2d',
                 pytorch_args='kernel_size, stride=None, padding=0, dilation=1, return_indices=False',
                 jittor_func_module='nn',
                 jittor_func_name='Pool',
                 jittor_args='kernel_size, stride=None, padding=0, dilation=None, return_indices=None, ceil_mode=False, op="maximum"',
                 extras={"op": "'mean'"})
    '''
    if links == None: links = {}
    if extras == None: extras = {}
    if delete == None: delete = []
    assert isinstance(links, dict)
    assert isinstance(extras, dict)
    assert isinstance(delete, list)
    pjmap[pytorch_func_name] = {
        'pytorch': {
            'args': pytorch_args,
        },
        'jittor': {
            'module': jittor_func_module,
            'name': jittor_func_name,
            'args': jittor_args,
        },
        'links': links,
        'extras': extras,
        'delete': delete,
    }

unsupport_ops = [
    # ***************************************************************
    # torch.nn
    # ***************************************************************
    'Parameter', 'ModuleList', 'ModuleDict', 'ParameterList', 'ParameterDict',
    'Conv1d', 'Conv3d', 'ConvTranspose1d', 'ConvTranspose3d', 'Unfold', 'Fold',
    'MaxPool1d', 'MaxPool3d', 'MaxUnpool1d', 'MaxUnpool2d', 'MaxUnpool3d', 'AvgPool1d', 'AvgPool3d', 'FractionalMaxPool2d', 'LPPool1d', 'LPPool2d', 'AdaptiveMaxPool1d', 'AdaptiveMaxPool2d', 'AdaptiveMaxPool3d', 'AdaptiveAvgPool1d', 'AdaptiveAvgPool3d',
    'ReflectionPad1d', 'ReflectionPad2d', 'ReplicationPad1d', 'ReplicationPad2d', 'ReplicationPad3d', 'ZeroPad2d', 'ConstantPad1d', 'ConstantPad2d', 'ConstantPad3d', 'ELU', 'Hardshrink', 'Hardtanh', 'LogSigmoid', 'MultiheadAttention',
    'PReLU', 'RReLU', 'SELU', 'CELU', 'GELU', 'Softplus', 'Softshrink', 'Softsign', 'Tanhshrink', 'Threshold', 'Softmin', 'Softmax2d', 'LogSoftmax', 'AdaptiveLogSoftmaxWithLoss', 'BatchNorm1d', 'BatchNorm3d', 'GroupNorm', 'SyncBatchNorm', 'InstanceNorm1d', 'InstanceNorm2d', 'InstanceNorm3d', 'LayerNorm', 'LocalResponseNorm', 'RNNBase', 'RNN', 'LSTM', 'GRU', 'RNNCell', 'LSTMCell', 'GRUCell', 'Transformer', 'TransformerEncoder', 'TransformerDecoder', 'TransformerEncoderLayer', 'TransformerDecoderLayer', 'Identity', 'Bilinear', 'Dropout3d', 'AlphaDropout', 'Embedding', 'EmbeddingBag', 'CosineSimilarity', 'PairwiseDistance', 'L1Loss', 'MSELoss', 'CTCLoss', 'NLLLoss', 'PoissonNLLLoss', 'KLDivLoss', 'BCELoss', 'BCEWithLogitsLoss', 'MarginRankingLoss', 'HingeEmbeddingLoss', 'MultiLabelMarginLoss', 'SmoothL1Loss', 'SoftMarginLoss', 'MultiLabelSoftMarginLoss', 'CosineEmbeddingLoss', 'MultiMarginLoss', 'TripletMarginLoss', 'PixelShuffle', 'Upsample', 'UpsamplingNearest2d', 'UpsamplingBilinear2d', 'DataParallel', 'DistributedDataParallel', 'clip_grad_norm_', 'clip_grad_value_', 'parameters_to_vector', 'vector_to_parameters', 'BasePruningMethod', 'PruningContainer', 'Identity', 'RandomUnstructured', 'L1Unstructured', 'RandomStructured', 'LnStructured', 'CustomFromMask', 'identity', 'random_unstructured', 'l1_unstructured', 'random_structured', 'ln_structured', 'global_unstructured', 'custom_from_mask', 'remove', 'is_pruned', 'weight_norm', 'remove_weight_norm', 'spectral_norm', 'remove_spectral_norm', 'PackedSequence', 'pack_padded_sequence', 'pad_packed_sequence', 'pad_sequence', 'pack_sequence'
    'MaxPool1d', 'MaxPool3d', 'MaxUnpool1d', 'MaxUnpool2d', 'MaxUnpool3d', 'AvgPool1d',
    'AvgPool3d', 'FractionalMaxPool2d', 'LPPool1d', 'LPPool2d', 'AdaptiveMaxPool1d',
    'AdaptiveMaxPool2d', 'AdaptiveMaxPool3d', 'AdaptiveAvgPool1d', 'AdaptiveAvgPool3d',
    'ReflectionPad1d', 'ReplicationPad1d', 'ReplicationPad3d', 'ConstantPad1d', 'ConstantPad3d',
    'ELU', 'Hardshrink', 'Hardtanh', 'LogSigmoid', 'MultiheadAttention',
    'RReLU', 'SELU', 'CELU', 'GELU', 'Softplus', 'Softshrink', 'Softsign', 'Tanhshrink',
    'Threshold', 'Softmin', 'Softmax2d', 'LogSoftmax', 'AdaptiveLogSoftmaxWithLoss',
    'BatchNorm3d', 'GroupNorm', 'SyncBatchNorm', 'InstanceNorm1d', 'InstanceNorm3d', 'LocalResponseNorm',
    'RNNBase', 'RNN', 'LSTM', 'GRU', 'RNNCell', 'LSTMCell', 'GRUCell', 'Transformer', 'TransformerEncoder',
    'TransformerDecoder', 'TransformerEncoderLayer', 'TransformerDecoderLayer', 'Identity', 'Bilinear',
    'Dropout3d', 'AlphaDropout', 'EmbeddingBag', 'CosineSimilarity', 'PairwiseDistance', 'L1Loss',
    'MSELoss', 'CTCLoss', 'NLLLoss', 'PoissonNLLLoss', 'KLDivLoss', 'BCELoss', 'BCEWithLogitsLoss',
    'MarginRankingLoss', 'HingeEmbeddingLoss', 'MultiLabelMarginLoss', 'SmoothL1Loss', 'SoftMarginLoss',
    'MultiLabelSoftMarginLoss', 'CosineEmbeddingLoss', 'MultiMarginLoss', 'TripletMarginLoss', 'UpsamplingNearest2d',
    'UpsamplingBilinear2d', 'DataParallel', 'DistributedDataParallel', 'clip_grad_norm_', 'clip_grad_value_',
    'parameters_to_vector', 'vector_to_parameters', 'BasePruningMethod', 'PruningContainer', 'Identity',
    'RandomUnstructured', 'L1Unstructured', 'RandomStructured', 'LnStructured', 'CustomFromMask', 'identity',
    'random_unstructured', 'l1_unstructured', 'random_structured', 'ln_structured', 'global_unstructured',
    'custom_from_mask', 'remove', 'is_pruned', 'weight_norm', 'remove_weight_norm', 'spectral_norm',
    'remove_spectral_norm', 'PackedSequence', 'pack_padded_sequence', 'pad_packed_sequence', 'pad_sequence', 'pack_sequence'
]

support_ops = {}
@@ -280,6 +391,30 @@ def replace(a):

import_flag = []
def convert(code):
    ''' Model code converter, example:

    from jittor.utils.pytorch_converter import convert
    pytorch_code = """
    class Model(nn.Module):
        def __init__(self):
            super().__init__()
            self.conv1 = nn.Conv2d(1, 10, 3)
            self.conv2 = nn.Conv2d(10, 32, 3)
            self.fc = nn.Linear(1200, 100)

        def forward(self, x):
            x = self.conv1(x)
            x = self.conv2(x)
            x = x.view(x.size(0), -1)
            x = self.fc(x)
            return x
    """
    jittor_code = convert(pytorch_code)
    print("## Generate Jittor code:", jittor_code)
    exec(jittor_code)
    model = Model()
    print("## Jittor model:", model)
    '''
    a = ast.parse(code)
    dfs(a)
    a.body.insert(0, ast.parse('import jittor as jt').body[0])
@@ -298,6 +433,10 @@ def convert_(prefix, func_name, ags, kws):
    else:
        p_ags = info['pytorch']['args']
        j_ags = info['jittor']['args']
        if 'delete' in info.keys():
            delete = info['delete']
        else:
            delete = None
        j_prefix = info['jittor']['prefix'] if 'prefix' in info['jittor'].keys() else None
        j_module = info['jittor']['module']
        j_name = info['jittor']['name']
@@ -333,6 +472,12 @@ def convert_(prefix, func_name, ags, kws):
            pp_ags.append(p_ag)
    if len(jj_ags) == 0 and len(pp_ags) != 0:
        raise AttributeError(f"{func_name} in Jittor has no Attribute {pp_ags[0]}")
    if delete is not None:
        for d in delete:
            if d in pp_ags:
                jj_ags.append(d)
            if d in pp_kws.keys():
                jj_kws[d] = None
    if len(pp_ags) > len(ags) + len(kws):
        raise RuntimeError(f'There are needed {len(pp_ags) + len(list(pp_kws.keys()))} args in Pytorch {func_name} function, but you only provide {len(ags) + len(kws)}')
    ags_ = []
@@ -395,6 +540,12 @@ def convert_(prefix, func_name, ags, kws):
            j_kws_values[k] = extras[k]
        else:
            raise AttributeError(f"there is not attribute named {k} in Jittor {func_name}, you should delete it in {func_name} extras.")
    if delete is not None:
        for d in delete:
            if d in j_ags_values:
                j_ags_values.remove(d)
            if d in j_kws_values.keys():
                j_kws_values.pop(d)
    j_ags_ = [j_ags_values[str(i)] for i in range(len(list(j_ags_values.keys())))]
    j_kws_ = [key + "=" + j_kws_values[key] for key in j_kws_values.keys()]
    j_func = f"{j_module}.{j_name}({', '.join(j_ags_+j_kws_)})"
@@ -412,10 +563,10 @@ def dfs(a):
        if 'torch' in astunparse.unparse(a) and 'init' in astunparse.unparse(a):
            import_flag.append('init')
            return ast.parse('from jittor import init').body[0]
        if 'torch' in astunparse.unparse(a) and 'nn' in astunparse.unparse(a):
        if 'torch' in astunparse.unparse(a) and a.names[0].asname == 'nn':
            import_flag.append('nn')
            return ast.parse('from jittor import nn').body[0]
        if a.names[0].name == 'torch':
        if 'torch' in a.names[0].name:
            return 'delete'
    elif isinstance(a, ast.ImportFrom):
        if 'torch' in a.module:
@@ -460,7 +611,6 @@ def dfs(a):
            ret = dfs(a_)
            if ret is 'delete':
                delete_flag.append(True)
                del a.__dict__[k][i]
                continue
            if ret is not None:
                a.__dict__[k][i] = ret
@@ -470,4 +620,4 @@ def dfs(a):
        else:
            ret = dfs(a.__dict__[k])
            if ret is not None:
                a.__dict__[k] = ret
                a.__dict__[k] = ret
@@ -0,0 +1,122 @@
# ***************************************************************
# Copyright (c) 2020 Jittor. Authors:
# Wenyang Zhou <576825820@qq.com>
# Dun Liang <randonlang@gmail.com>.
# All Rights Reserved.
# This file is subject to the terms and conditions defined in
# file 'LICENSE.txt', which is part of this source code package.
# ***************************************************************
import numpy as np
import jittor as jt
import torch
import time
import jittor.models as jtmodels
import torchvision.models as tcmodels
import os

jt.flags.use_cuda = 1
torch.backends.cudnn.deterministic = False
torch.backends.cudnn.benchmark = True
jt.cudnn.set_algorithm_cache_size(10000)

threshold = 1e-3

models = [
    # 'squeezenet1_0',
    'squeezenet1_1',
    'alexnet',
    # 'resnet18',
    # 'resnet34',
    'resnet50',
    # 'resnet101',
    'resnet152',
    'resnext50_32x4d',
    'resnext101_32x8d',
    'vgg11',
    # 'vgg11_bn',
    # 'vgg13',
    # 'vgg13_bn',
    # 'vgg16',
    # 'vgg16_bn',
    # 'vgg19',
    # 'vgg19_bn',
    'wide_resnet50_2',
    'wide_resnet101_2',
]

def to_cuda(x):
    if jt.has_cuda:
        return x.cuda()
    return x

def test_allmodels(bs=1):
    # Define numpy input image
    test_img = np.random.random((bs,3,224,224)).astype('float32')
    # Define pytorch & jittor input image
    pytorch_test_img = to_cuda(torch.Tensor(test_img))
    jittor_test_img = jt.array(test_img)
    for model in models:
        if model == "inception_v3":
            test_img = np.random.random((bs,3,300,300)).astype('float32')
            pytorch_test_img = to_cuda(torch.Tensor(test_img))
            jittor_test_img = jt.array(test_img)

        jittor_test_img.stop_grad()
        pytorch_test_img.requires_grad = False

        # Define pytorch & jittor model
        pytorch_model = to_cuda(tcmodels.__dict__[model]())
        jittor_model = jtmodels.__dict__[model]()
        # Set eval to avoid dropout layer
        pytorch_model.eval()
        jittor_model.eval()
        # Jittor loads pytorch parameters to ensure forward alignment
        jittor_model.load_parameters(pytorch_model.state_dict())

        total = 512
        warmup = max(2, total // bs // 8)
        rerun = max(2, total // bs)

        print("=" * 20 + model + "=" * 20)

        # Jittor warms up
        for i in range(warmup):
            jittor_result = jittor_model(jittor_test_img)
        jt.sync_all(True)
        # Test jittor and once forward time
        sta = time.time()
        for i in range(rerun):
            jittor_result = jittor_model(jittor_test_img)
            jittor_result.sync()
        jt.sync_all(True)
        end = time.time()
        print(f"- Jittor {model} forward average time cost: {round((time.time() - sta) / rerun,5)}, Batch Size: {bs}, FPS: {round(bs * rerun / (end - sta),2)}")

        # pytorch warmup
        for i in range(warmup):
            pytorch_result = pytorch_model(pytorch_test_img)
        # Test pytorch and once forward time
        torch.cuda.synchronize()
        sta = time.time()
        for i in range(rerun):
            pytorch_result = pytorch_model(pytorch_test_img)
        torch.cuda.synchronize()
        end = time.time()
        print(f"- Pytorch {model} forward average time cost: {round((end - sta) / rerun,5)}, Batch Size: {bs}, FPS: {round(bs * rerun / (end - sta),2)}")

        # Judge pytorch & jittor forward relative error. If the differece is lower than threshold, this test passes.
        x = pytorch_result.detach().cpu().numpy() + 1
        y = jittor_result.numpy() + 1
        relative_error = abs(x - y) / abs(y)
        diff = relative_error.mean()
        assert diff < threshold, f"[*] {model} forward fails..., Relative Error: {diff}"
        print(f"[*] {model} forword passes with Relative Error {diff}")
        torch.cuda.empty_cache()
        jt.clean()
        jt.gc()


with torch.no_grad():
    for bs in [1,2,4,8,16,32,64,128]:
        # for bs in [128]:
        test_allmodels(bs)