Merge pull request #73 from Jittor/zwy

add some ops
zhouwy19 2020-04-30 22:18:46 +08:00 committed by GitHub
commit 67d7053322
9 changed files with 522 additions and 18 deletions

View File

@ -699,7 +699,7 @@ def jittor_exit():
atexit.register(jittor_exit)
Var.__str__ = lambda x: str(x.data)
Var.__repr__ = lambda x: f"jt.Var:{x.dtype}{x.uncertain_shape}"
Var.__repr__ = lambda x: str(x.data)
Var.peek = lambda x: f"{x.dtype}{x.shape}"
from . import nn
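A quick sketch of the new printing behaviour (the exact shape formatting is Jittor's internal string form):

import jittor as jt
v = jt.array([1.0, 2.0, 3.0])
print(repr(v))    # now prints the underlying data, same as str(v)
print(v.peek())   # compact dtype/shape summary, e.g. float32[3,]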

View File

@ -58,6 +58,7 @@ def argmax_pool(x, size, stride, padding=0):
def concat(arr, dim):
# TODO: low performance when concatenating many vars
total_dim = 0
if dim < 0: dim += len(arr[0].shape)
for a in arr:
total_dim += a.shape[dim]
cdim = 0
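A behavioural sketch of the negative-dim handling added above (the import path is an assumption; concat is defined next to argmax_pool in this file):

import jittor as jt
from jittor import contrib          # assumed module path for the concat defined above
a = jt.array([[1.0, 2.0], [3.0, 4.0]])
b = jt.array([[5.0, 6.0], [7.0, 8.0]])
c = contrib.concat([a, b], dim=-1)  # dim=-1 now resolves to the last axis (dim=1); result shape (2, 4)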

View File

@ -75,9 +75,21 @@ def linear(x, n):
return jt.matmul(x, w) + b
def relu(x): return jt.maximum(x, 0)
def leaky_relu(x, scale): return jt.ternary(x>0, x, x*scale)
def leaky_relu(x, scale=0.01): return jt.ternary(x>0, x, x*scale)
def relu6(x): return jt.minimum(jt.maximum(x, 0), 6)
class PReLU(Module):
def __init__(self, num_parameters=1, init_=0.25):
self.num_parameters = num_parameters
self.a = init.constant((num_parameters,), "float32", init_)
def execute(self, x):
if self.num_parameters != 1:
assert self.num_parameters == x.size(1), f"num_parameters ({self.num_parameters}) does not match the number of input channels ({x.size(1)}) in PReLU"
return jt.maximum(0, x) + self.a.broadcast(x, [0,2,3]) * jt.minimum(0, x)
else:
return jt.maximum(0, x) + self.a * jt.minimum(0, x)
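A minimal usage sketch of the new PReLU module on an NCHW input:

import jittor as jt
from jittor import nn
x = jt.random((2, 10, 8, 8)) - 0.5              # mixed-sign activations, 10 channels
prelu = nn.PReLU(num_parameters=10, init_=0.25)
y = prelu(x)                                    # negative entries are scaled by the per-channel parameter a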
# TODO: 4-dim input causes slow execution
def cross_entropy_loss(output, target, ignore_index=None):
if len(output.shape) == 4:
@ -317,7 +329,7 @@ class BatchNorm(Module):
Relu = jt.make_module(relu)
ReLU = Relu
Leaky_relu = jt.make_module(leaky_relu, 0.01)
Leaky_relu = jt.make_module(leaky_relu, 2)
LeakyReLU = Leaky_relu
ReLU6 = jt.make_module(relu6)
Softmax = jt.make_module(softmax, 2)
@ -396,7 +408,7 @@ class Conv(Module):
if self.bias is not None:
b = self.bias.broadcast(y.shape, [0,2,3])
y = y + b
return y
return y
class ConvTranspose(Module):
@ -451,6 +463,166 @@ class ConvTranspose(Module):
return y
class ReflectionPad2d(Module):
def __init__(self, padding):
self.padding = padding
if isinstance(self.padding, int):
self.pl = self.padding
self.pr = self.padding
self.pt = self.padding
self.pb = self.padding
elif isinstance(self.padding, tuple):
self.pl, self.pr, self.pt, self.pb = self.padding
else:
raise TypeError(f"ReflectionPad2d padding just support int or tuple, but found {type(padding)}")
def execute(self, x):
n,c,h,w = x.shape
assert (self.pl < w and self.pr < w), "padding_left and padding_right must be smaller than the input width"
assert (self.pt < h and self.pb < h), "padding_top and padding_bottom must be smaller than the input height"
oh=h+self.pt+self.pb
ow=w+self.pl+self.pr
l = self.pl
r = self.pl + w - 1
t = self.pt
b = self.pt + h - 1
x_idx = np.zeros((oh,ow))
y_idx = np.zeros((oh,ow))
for j in range(oh):
for i in range(ow):
if i >= l and i <= r and j >= t and j <= b:
x_idx[j,i] = i
y_idx[j,i] = j
elif i < l and j < t:
x_idx[j,i] = 2 * l - i
y_idx[j,i] = 2 * t - j
elif i < l and j > b:
x_idx[j,i] = 2 * l - i
y_idx[j,i] = 2 * b - j
elif i > r and j < t:
x_idx[j,i] = 2 * r - i
y_idx[j,i] = 2 * t - j
elif i > r and j > b:
x_idx[j,i] = 2 * r - i
y_idx[j,i] = 2 * b - j
elif i < l:
x_idx[j,i] = 2 * l - i
y_idx[j,i] = j
elif i > r:
x_idx[j,i] = 2 * r - i
y_idx[j,i] = j
elif j < t:
x_idx[j,i] = i
y_idx[j,i] = 2 * t - j
elif j > b:
x_idx[j,i] = i
y_idx[j,i] = 2 * b - j
return x.reindex([n,c,oh,ow], ["i0","i1","@e1(i2,i3)","@e0(i2,i3)"], extras=[jt.array(x_idx - self.pl), jt.array(y_idx - self.pt)])
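A small sketch of the reflection semantics (the edge row/column is mirrored but not repeated), matching PyTorch's ReflectionPad2d:

import jittor as jt
from jittor import nn
x = jt.array([[[[0., 1., 2.],
                [3., 4., 5.],
                [6., 7., 8.]]]])   # shape (1, 1, 3, 3)
y = nn.ReflectionPad2d(1)(x)       # shape (1, 1, 5, 5)
# the first output row reflects input row 1: [4, 3, 4, 5, 4]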
class ZeroPad2d(Module):
def __init__(self, padding):
self.padding = padding
if isinstance(self.padding, int):
self.pl = self.padding
self.pr = self.padding
self.pt = self.padding
self.pb = self.padding
elif isinstance(self.padding, tuple):
self.pl, self.pr, self.pt, self.pb = self.padding
else:
raise TypeError(f"ZeroPad2d padding just support int or tuple, but found {type(padding)}")
def execute(self, x):
n,c,h,w = x.shape
return x.reindex([n,c,h+self.pt+self.pb,w+self.pl+self.pr], ["i0","i1",f"i2-{self.pt}",f"i3-{self.pl}"])
class ConstantPad2d(Module):
def __init__(self, padding, value):
self.padding = padding
if isinstance(self.padding, int):
self.pl = self.padding
self.pr = self.padding
self.pt = self.padding
self.pb = self.padding
elif isinstance(self.padding, tuple):
self.pl, self.pr, self.pt, self.pb = self.padding
else:
raise TypeError(f"ConstantPad2d padding just support int or tuple, but found {type(padding)}")
self.value = value
def execute(self, x):
n,c,h,w = x.shape
return x.reindex([n,c,h+self.pt+self.pb,w+self.pl+self.pr], ["i0","i1",f"i2-{self.pt}",f"i3-{self.pl}"], overflow_value=self.value)
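ZeroPad2d and ConstantPad2d both route out-of-range indices through reindex's overflow_value (0 for ZeroPad2d); a minimal shape/value sketch:

import numpy as np
import jittor as jt
from jittor import nn
x = jt.array(np.ones((1, 1, 2, 2), dtype="float32"))
y0 = nn.ZeroPad2d(1)(x)            # shape (1, 1, 4, 4); border filled with 0
yc = nn.ConstantPad2d(1, -5.0)(x)  # shape (1, 1, 4, 4); border filled with -5.0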
class ReplicationPad2d(Module):
def __init__(self, padding):
self.padding = padding
if isinstance(self.padding, int):
self.pl = self.padding
self.pr = self.padding
self.pt = self.padding
self.pb = self.padding
elif isinstance(self.padding, tuple):
self.pl, self.pr, self.pt, self.pb = self.padding
else:
raise TypeError(f"ReplicationPad2d padding just support int or tuple, but found {type(padding)}")
def execute(self, x):
n,c,h,w = x.shape
oh=h+self.pt+self.pb
ow=w+self.pl+self.pr
l = self.pl
r = self.pl + w - 1
t = self.pt
b = self.pt + h - 1
x_idx = np.zeros((oh,ow))
y_idx = np.zeros((oh,ow))
for j in range(oh):
for i in range(ow):
if i >= l and i <= r and j >= t and j <= b:
x_idx[j,i] = i
y_idx[j,i] = j
elif i < l and j < t:
x_idx[j,i] = l
y_idx[j,i] = t
elif i < l and j > b:
x_idx[j,i] = l
y_idx[j,i] = b
elif i > r and j < t:
x_idx[j,i] = r
y_idx[j,i] = t
elif i > r and j > b:
x_idx[j,i] = r
y_idx[j,i] = b
elif i < l:
x_idx[j,i] = l
y_idx[j,i] = j
elif i > r:
x_idx[j,i] = r
y_idx[j,i] = j
elif j < t:
x_idx[j,i] = i
y_idx[j,i] = t
elif j > b:
x_idx[j,i] = i
y_idx[j,i] = b
return x.reindex([n,c,oh,ow], ["i0","i1","@e1(i2,i3)","@e0(i2,i3)"], extras=[jt.array(x_idx - self.pl), jt.array(y_idx - self.pt)])
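Replication padding clamps indices to the nearest edge instead of reflecting; a minimal sketch with an asymmetric (left, right, top, bottom) tuple:

import jittor as jt
from jittor import nn
x = jt.array([[[[1., 2.],
                [3., 4.]]]])               # shape (1, 1, 2, 2)
y = nn.ReplicationPad2d((1, 0, 1, 0))(x)   # pad left and top by 1; shape (1, 1, 3, 3)
# result: [[1, 1, 2], [1, 1, 2], [3, 3, 4]]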
class PixelShuffle(Module):
def __init__(self, upscale_factor):
self.upscale_factor = upscale_factor
def execute(self, x):
n,c,h,w = x.shape
r = self.upscale_factor
assert c%(r**2)==0, f"the number of input channels ({c}) must be divisible by the square of upscale_factor in PixelShuffle"
return x.reindex([n,int(c/r**2),h*r,w*r], [
"i0",
f"i1*{r**2}+i2%{r}*{r}+i3%{r}",
f"i2/{r}",
f"i3/{r}"
])
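PixelShuffle rearranges a (N, C*r**2, H, W) tensor into (N, C, H*r, W*r); a minimal shape sketch:

import jittor as jt
from jittor import nn
x = jt.random((1, 16, 4, 4))               # 16 = 4 * 2**2 channels
y = nn.PixelShuffle(upscale_factor=2)(x)   # shape (1, 4, 8, 8)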
class Tanh(Module):
def __init__(self):
super().__init__()
@ -469,8 +641,8 @@ def resize(x, size, mode="nearest"):
H,W = size
new_size = [n,c,H,W]
nid, cid, hid, wid = jt.index(new_size)
x = hid * ((h-1)/(H-1))
y = wid * ((w-1)/(W-1))
x = hid * h / H
y = wid * w / W
if mode=="nearest":
return img.reindex([nid, cid, x.floor(), y.floor()])
if mode=="bilinear":
@ -488,7 +660,13 @@ def resize(x, size, mode="nearest"):
return o
raise(f"Not support {interpolation}")
class Upsample(Module):
def __init__(self, scale_factor=None, mode='nearest'):
self.scale_factor = scale_factor if isinstance(scale_factor, tuple) else (scale_factor, scale_factor)
self.mode = mode
def execute(self, x):
return resize(x, size=(int(x.shape[2]*self.scale_factor[0]), int(x.shape[3]*self.scale_factor[1])), mode=self.mode)
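With the revised index scaling (h/H and w/W), nearest-neighbour resize matches the convention that Upsample relies on; a minimal sketch:

import jittor as jt
from jittor import nn
x = jt.random((1, 3, 32, 32))
up = nn.Upsample(scale_factor=2, mode='nearest')
y = up(x)                                  # shape (1, 3, 64, 64)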
class Sequential(Module):
def __init__(self, *args):

View File

@ -0,0 +1,66 @@
# ***************************************************************
# Copyright (c) 2020 Jittor. Authors:
# Wenyang Zhou <576825820@qq.com>
# Dun Liang <randonlang@gmail.com>.
# All Rights Reserved.
# This file is subject to the terms and conditions defined in
# file 'LICENSE.txt', which is part of this source code package.
# ***************************************************************
import unittest
import jittor as jt
import numpy as np
import jittor.nn as jnn
skip_this_test = False
try:
jt.dirty_fix_pytorch_runtime_error()
import torch
import torch.nn as tnn
except:
torch = None
tnn = None
skip_this_test = True
def check_equal(arr, j_layer, p_layer, is_train=True, threshold=1e-5):
jittor_arr = jt.array(arr)
pytorch_arr = torch.Tensor(arr)
if is_train:
assert np.allclose(p_layer.running_mean.detach().numpy(), j_layer.running_mean.numpy(), threshold)
assert np.allclose(p_layer.running_var.detach().numpy(), j_layer.running_var.numpy(), threshold)
else:
assert np.allclose(p_layer.layer.running_mean.detach().numpy(), j_layer.running_mean.numpy(), threshold)
assert np.allclose(p_layer.layer.running_var.detach().numpy(), j_layer.running_var.numpy(), threshold)
jittor_result = j_layer(jittor_arr)
pytorch_result = p_layer(pytorch_arr)
if is_train:
assert np.allclose(p_layer.running_mean.detach().numpy(), j_layer.running_mean.numpy(), threshold)
assert np.allclose(p_layer.running_var.detach().numpy(), j_layer.running_var.numpy(), threshold)
else:
assert np.allclose(p_layer.layer.running_mean.detach().numpy(), j_layer.running_mean.numpy(), threshold)
assert np.allclose(p_layer.layer.running_var.detach().numpy(), j_layer.running_var.numpy(), threshold)
assert np.allclose(pytorch_result.detach().numpy(), jittor_result.numpy(), threshold)
@unittest.skipIf(skip_this_test, "No Torch found")
class TestBatchNorm(unittest.TestCase):
def test_batchnorm(self):
# ***************************************************************
# Test BatchNorm Layer
# ***************************************************************
arr = np.random.randn(16,10,224,224)
check_equal(arr, jnn.BatchNorm(10, is_train=True), tnn.BatchNorm2d(10))
class Model(tnn.Module):
def __init__(self):
super(Model, self).__init__()
self.layer = tnn.BatchNorm2d(10)
def forward(self, x):
return self.layer(x)
model = Model()
model.eval()
check_equal(arr, jnn.BatchNorm(10, is_train=False), model, False)
if __name__ == "__main__":
unittest.main()

View File

@ -0,0 +1,65 @@
# ***************************************************************
# Copyright (c) 2020 Jittor. Authors:
# Wenyang Zhou <576825820@qq.com>
# Dun Liang <randonlang@gmail.com>.
# All Rights Reserved.
# This file is subject to the terms and conditions defined in
# file 'LICENSE.txt', which is part of this source code package.
# ***************************************************************
import unittest
import jittor as jt
import numpy as np
import jittor.nn as jnn
skip_this_test = False
try:
jt.dirty_fix_pytorch_runtime_error()
import torch
import torch.nn as tnn
except:
torch = None
tnn = None
skip_this_test = True
def check_equal(arr, j_layer, p_layer):
jittor_arr = jt.array(arr)
pytorch_arr = torch.Tensor(arr)
jittor_result = j_layer(jittor_arr)
pytorch_result = p_layer(pytorch_arr)
assert np.allclose(pytorch_result.detach().numpy(), jittor_result.numpy())
@unittest.skipIf(skip_this_test, "No Torch found")
class TestPad(unittest.TestCase):
def test_pad(self):
# ***************************************************************
# Test ReplicationPad2d Layer
# ***************************************************************
arr = np.random.randn(16,3,224,224)
check_equal(arr, jnn.ReplicationPad2d(10), tnn.ReplicationPad2d(10))
check_equal(arr, jnn.ReplicationPad2d((1,23,4,5)), tnn.ReplicationPad2d((1,23,4,5)))
# ***************************************************************
# Test ConstantPad2d Layer
# ***************************************************************
arr = np.random.randn(16,3,224,224)
check_equal(arr, jnn.ConstantPad2d(10,-2), tnn.ConstantPad2d(10,-2))
check_equal(arr, jnn.ConstantPad2d((2,3,34,1),10.2), tnn.ConstantPad2d((2,3,34,1),10.2))
# ***************************************************************
# Test ZeroPad2d Layer
# ***************************************************************
arr = np.random.randn(16,3,224,224)
check_equal(arr, jnn.ZeroPad2d(1), tnn.ZeroPad2d(1))
check_equal(arr, jnn.ZeroPad2d((2,3,34,1)), tnn.ZeroPad2d((2,3,34,1)))
# ***************************************************************
# Test ReflectionPad2d Layer
# ***************************************************************
arr = np.random.randn(16,3,224,224)
check_equal(arr, jnn.ReflectionPad2d(20), tnn.ReflectionPad2d(20))
check_equal(arr, jnn.ReflectionPad2d((2,3,34,1)), tnn.ReflectionPad2d((2,3,34,1)))
if __name__ == "__main__":
unittest.main()

View File

@ -0,0 +1,66 @@
# ***************************************************************
# Copyright (c) 2020 Jittor. Authors:
# Wenyang Zhou <576825820@qq.com>
# Dun Liang <randonlang@gmail.com>.
# All Rights Reserved.
# This file is subject to the terms and conditions defined in
# file 'LICENSE.txt', which is part of this source code package.
# ***************************************************************
import unittest
import jittor as jt
import numpy as np
import jittor.nn as jnn
skip_this_test = False
try:
jt.dirty_fix_pytorch_runtime_error()
import torch
import torch.nn as tnn
except:
torch = None
tnn = None
skip_this_test = True
def check_equal(arr, j_layer, p_layer):
jittor_arr = jt.array(arr)
pytorch_arr = torch.Tensor(arr)
jittor_result = j_layer(jittor_arr)
pytorch_result = p_layer(pytorch_arr)
assert np.allclose(pytorch_result.detach().numpy(), jittor_result.numpy())
@unittest.skipIf(skip_this_test, "No Torch found")
class TestRelu(unittest.TestCase):
def test_relu(self):
# ***************************************************************
# Test ReLU Layer
# ***************************************************************
arr = np.random.randn(16,10,224,224)
check_equal(arr, jnn.ReLU(), tnn.ReLU())
# ***************************************************************
# Test PReLU Layer
# ***************************************************************
arr = np.random.randn(16,10,224,224)
check_equal(arr, jnn.PReLU(), tnn.PReLU())
check_equal(arr, jnn.PReLU(10, 99.9), tnn.PReLU(10, 99.9))
check_equal(arr, jnn.PReLU(10, 2), tnn.PReLU(10, 2))
check_equal(arr, jnn.PReLU(10, -0.2), tnn.PReLU(10, -0.2))
# ***************************************************************
# Test ReLU6 Layer
# ***************************************************************
arr = np.random.randn(16,10,224,224)
check_equal(arr, jnn.ReLU6(), tnn.ReLU6())
# ***************************************************************
# Test LeakyReLU Layer
# ***************************************************************
arr = np.random.randn(16,10,224,224)
check_equal(arr, jnn.LeakyReLU(), tnn.LeakyReLU())
check_equal(arr, jnn.LeakyReLU(2), tnn.LeakyReLU(2))
check_equal(arr, jnn.LeakyReLU(99.9), tnn.LeakyReLU(99.9))
if __name__ == "__main__":
unittest.main()

View File

@ -11,6 +11,17 @@ import jittor as jt
import random
import os
import numpy as np
import jittor.nn as jnn
try:
jt.dirty_fix_pytorch_runtime_error()
import torch
import torch.nn as tnn
except:
torch = None
tnn = None
skip_this_test = True
mid = 0
if os.uname()[1] == "jittor-ce":
mid = 1
@ -74,6 +85,13 @@ def test_case(box_num, out_size, time_limit):
assert fused_op_num == 1, fused_op_num
assert t <= time_limit, t
def check_equal(arr, j_layer, p_layer):
jittor_arr = jt.array(arr)
pytorch_arr = torch.Tensor(arr)
jittor_result = j_layer(jittor_arr)
pytorch_result = p_layer(pytorch_arr)
assert np.allclose(pytorch_result.detach().numpy(), jittor_result.numpy())
class TestResizeAndCrop(unittest.TestCase):
def test(self):
test_case(100, [224, 224], 0.45)
@ -81,5 +99,16 @@ class TestResizeAndCrop(unittest.TestCase):
test_case(20, [1024, 1024], [1.2, 1.8][mid])
test_case(20, [1024, 666], [0.8,1.0][mid])
def test_upsample(self):
arr = np.random.randn(16,10,224,224)
check_equal(arr, jnn.Upsample(scale_factor=2), tnn.Upsample(scale_factor=2))
check_equal(arr, jnn.Upsample(scale_factor=0.2), tnn.Upsample(scale_factor=0.2))
def test_pixelshuffle(self):
arr = np.random.randn(16,16,224,224)
check_equal(arr, jnn.PixelShuffle(upscale_factor=2), tnn.PixelShuffle(upscale_factor=2))
arr = np.random.randn(1,16*16,224,224)
check_equal(arr, jnn.PixelShuffle(upscale_factor=16), tnn.PixelShuffle(upscale_factor=16))
if __name__ == "__main__":
unittest.main()

View File

@ -114,16 +114,31 @@ class Compose:
return data
class Resize:
def __init__(self, size):
def __init__(self, size, mode=Image.BILINEAR):
if isinstance(size, int):
size = (size, size)
assert isinstance(size, tuple)
self.size = size
self.mode = mode
def __call__(self, img:Image.Image):
return img.resize(self.size, Image.BILINEAR)
return img.resize(self.size, self.mode)
class Gray:
def __call__(self, img:Image.Image):
img = np.array(img.convert('L'))
img = img[np.newaxis, :]
return np.array((img / 255.0), dtype = np.float32)
class RandomCrop:
def __init__(self, size):
if isinstance(size, int):
size = (size, size)
assert isinstance(size, tuple)
self.size = size
def __call__(self, img:Image.Image):
width, height = img.size
assert self.size[0] <= height and self.size[1] <= width, f"crop size exceeds the input image in RandomCrop"
top = np.random.randint(0,height-self.size[0]+1)
left = np.random.randint(0,width-self.size[1]+1)
return crop(img, top, left, self.size[0], self.size[1])
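A sketch of chaining the new transforms on a PIL image (the module path jittor.transform and the input file name are assumptions):

from PIL import Image
from jittor import transform
preprocess = transform.Compose([
    transform.Resize(256),        # now accepts a resampling mode, default Image.BILINEAR
    transform.RandomCrop(224),
    transform.Gray(),             # float32 array of shape (1, H, W) scaled to [0, 1]
])
out = preprocess(Image.open("example.jpg"))   # hypothetical input image
print(out.shape, out.dtype)                   # (1, 224, 224) float32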

View File

@ -76,6 +76,7 @@ pjmap = {
},
'links': {},
'extras': {},
'delete': ['inplace'],
},
'ReLU6': {
'pytorch': {
@ -88,6 +89,19 @@ pjmap = {
},
'links': {},
'extras': {},
'delete': ['inplace'],
},
'PReLU': {
'pytorch': {
'args': 'num_parameters=1, init=0.25',
},
'jittor': {
'module': 'nn',
'name': 'PReLU',
'args': 'num_parameters=1, init_=0.25'
},
'links': {'init': 'init_'},
'extras': {},
},
'LeakyReLU': {
'pytorch': {
@ -96,10 +110,11 @@ pjmap = {
'jittor': {
'module': 'nn',
'name': 'LeakyReLU',
'args': 'scale'
'args': 'scale=0.01'
},
'links': {'negative_slope': 'scale'},
'extras': {},
'delete': ['inplace'],
},
'BatchNorm2d': {
'pytorch': {
@ -113,6 +128,19 @@ pjmap = {
'links': {},
'extras': {},
},
'BatchNorm1d': {
'pytorch': {
'args': "num_features, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True"
},
'jittor': {
'module': 'nn',
'name': 'BatchNorm1d',
'args': 'num_features, eps=1e-5, momentum=0.1, affine=None, is_train=True, sync=True',
},
'links': {},
'extras': {'affine': 'None'},
'delete': ['track_running_stats'],
},
'Dropout2d': {
'pytorch': {
'args': 'p=0.5, inplace=False',
@ -124,6 +152,19 @@ pjmap = {
},
'links': {},
'extras': {},
'delete': ['inplace'],
},
'Upsample': {
'pytorch': {
'args': "size=None, scale_factor=None, mode='nearest', align_corners=None",
},
'jittor': {
'module': 'nn',
'name': 'Upsample',
'args': "scale_factor=None, mode='nearest'"
},
'links': {},
'extras': {},
},
'kaiming_normal_': {
'pytorch': {
@ -161,6 +202,18 @@ pjmap = {
'links': {'tensor': 'var'},
'extras': {},
},
'uniform_': {
'pytorch': {
'args': "tensor, a=0.0, b=1.0",
},
'jittor': {
'module': 'init',
'name': 'uniform_',
'args': 'var, low, high'
},
'links': {'tensor': 'var', 'a': 'low', 'b': 'high'},
'extras': {},
},
'cat': {
'pytorch': {
'args': "tensors, dim=0, out=None",
@ -225,7 +278,6 @@ pjmap = {
'links': {},
'extras': {},
},
# probably not needed if it is exactly the same
'view': {
'pytorch': {
'prefix': [],
@ -250,9 +302,26 @@ unsupport_ops = [
# ***************************************************************
'Parameter', 'ModuleList', 'ModuleDict', 'ParameterList', 'ParameterDict',
'Conv1d', 'Conv3d', 'ConvTranspose1d', 'ConvTranspose3d', 'Unfold', 'Fold',
'MaxPool1d', 'MaxPool3d', 'MaxUnpool1d', 'MaxUnpool2d', 'MaxUnpool3d', 'AvgPool1d', 'AvgPool3d', 'FractionalMaxPool2d', 'LPPool1d', 'LPPool2d', 'AdaptiveMaxPool1d', 'AdaptiveMaxPool2d', 'AdaptiveMaxPool3d', 'AdaptiveAvgPool1d', 'AdaptiveAvgPool3d',
'ReflectionPad1d', 'ReflectionPad2d', 'ReplicationPad1d', 'ReplicationPad2d', 'ReplicationPad3d', 'ZeroPad2d', 'ConstantPad1d', 'ConstantPad2d', 'ConstantPad3d', 'ELU', 'Hardshrink', 'Hardtanh', 'LogSigmoid', 'MultiheadAttention',
'PReLU', 'RReLU', 'SELU', 'CELU', 'GELU', 'Softplus', 'Softshrink', 'Softsign', 'Tanhshrink', 'Threshold', 'Softmin', 'Softmax2d', 'LogSoftmax', 'AdaptiveLogSoftmaxWithLoss', 'BatchNorm1d', 'BatchNorm3d', 'GroupNorm', 'SyncBatchNorm', 'InstanceNorm1d', 'InstanceNorm2d', 'InstanceNorm3d', 'LayerNorm', 'LocalResponseNorm', 'RNNBase', 'RNN', 'LSTM', 'GRU', 'RNNCell', 'LSTMCell', 'GRUCell', 'Transformer', 'TransformerEncoder', 'TransformerDecoder', 'TransformerEncoderLayer', 'TransformerDecoderLayer', 'Identity', 'Bilinear', 'Dropout3d', 'AlphaDropout', 'Embedding', 'EmbeddingBag', 'CosineSimilarity', 'PairwiseDistance', 'L1Loss', 'MSELoss', 'CTCLoss', 'NLLLoss', 'PoissonNLLLoss', 'KLDivLoss', 'BCELoss', 'BCEWithLogitsLoss', 'MarginRankingLoss', 'HingeEmbeddingLoss', 'MultiLabelMarginLoss', 'SmoothL1Loss', 'SoftMarginLoss', 'MultiLabelSoftMarginLoss', 'CosineEmbeddingLoss', 'MultiMarginLoss', 'TripletMarginLoss', 'PixelShuffle', 'Upsample', 'UpsamplingNearest2d', 'UpsamplingBilinear2d', 'DataParallel', 'DistributedDataParallel', 'clip_grad_norm_', 'clip_grad_value_', 'parameters_to_vector', 'vector_to_parameters', 'BasePruningMethod', 'PruningContainer', 'Identity', 'RandomUnstructured', 'L1Unstructured', 'RandomStructured', 'LnStructured', 'CustomFromMask', 'identity', 'random_unstructured', 'l1_unstructured', 'random_structured', 'ln_structured', 'global_unstructured', 'custom_from_mask', 'remove', 'is_pruned', 'weight_norm', 'remove_weight_norm', 'spectral_norm', 'remove_spectral_norm', 'PackedSequence', 'pack_padded_sequence', 'pad_packed_sequence', 'pad_sequence', 'pack_sequence'
'MaxPool1d', 'MaxPool3d', 'MaxUnpool1d', 'MaxUnpool2d', 'MaxUnpool3d', 'AvgPool1d',
'AvgPool3d', 'FractionalMaxPool2d', 'LPPool1d', 'LPPool2d', 'AdaptiveMaxPool1d',
'AdaptiveMaxPool2d', 'AdaptiveMaxPool3d', 'AdaptiveAvgPool1d', 'AdaptiveAvgPool3d',
'ReflectionPad1d', 'ReplicationPad1d', 'ReplicationPad3d', 'ConstantPad1d', 'ConstantPad3d',
'ELU', 'Hardshrink', 'Hardtanh', 'LogSigmoid', 'MultiheadAttention',
'RReLU', 'SELU', 'CELU', 'GELU', 'Softplus', 'Softshrink', 'Softsign', 'Tanhshrink',
'Threshold', 'Softmin', 'Softmax2d', 'LogSoftmax', 'AdaptiveLogSoftmaxWithLoss',
'BatchNorm3d', 'GroupNorm', 'SyncBatchNorm', 'InstanceNorm1d', 'InstanceNorm3d', 'LocalResponseNorm',
'RNNBase', 'RNN', 'LSTM', 'GRU', 'RNNCell', 'LSTMCell', 'GRUCell', 'Transformer', 'TransformerEncoder',
'TransformerDecoder', 'TransformerEncoderLayer', 'TransformerDecoderLayer', 'Identity', 'Bilinear',
'Dropout3d', 'AlphaDropout', 'EmbeddingBag', 'CosineSimilarity', 'PairwiseDistance', 'L1Loss',
'MSELoss', 'CTCLoss', 'NLLLoss', 'PoissonNLLLoss', 'KLDivLoss', 'BCELoss', 'BCEWithLogitsLoss',
'MarginRankingLoss', 'HingeEmbeddingLoss', 'MultiLabelMarginLoss', 'SmoothL1Loss', 'SoftMarginLoss',
'MultiLabelSoftMarginLoss', 'CosineEmbeddingLoss', 'MultiMarginLoss', 'TripletMarginLoss', 'UpsamplingNearest2d',
'UpsamplingBilinear2d', 'DataParallel', 'DistributedDataParallel', 'clip_grad_norm_', 'clip_grad_value_',
'parameters_to_vector', 'vector_to_parameters', 'BasePruningMethod', 'PruningContainer', 'Identity',
'RandomUnstructured', 'L1Unstructured', 'RandomStructured', 'LnStructured', 'CustomFromMask', 'identity',
'random_unstructured', 'l1_unstructured', 'random_structured', 'ln_structured', 'global_unstructured',
'custom_from_mask', 'remove', 'is_pruned', 'weight_norm', 'remove_weight_norm', 'spectral_norm',
'remove_spectral_norm', 'PackedSequence', 'pack_padded_sequence', 'pad_packed_sequence', 'pad_sequence', 'pack_sequence'
]
support_ops = {}
@ -298,6 +367,10 @@ def convert_(prefix, func_name, ags, kws):
else:
p_ags = info['pytorch']['args']
j_ags = info['jittor']['args']
if 'delete' in info.keys():
delete = info['delete']
else:
delete = None
j_prefix = info['jittor']['prefix'] if 'prefix' in info['jittor'].keys() else None
j_module = info['jittor']['module']
j_name = info['jittor']['name']
@ -333,6 +406,12 @@ def convert_(prefix, func_name, ags, kws):
pp_ags.append(p_ag)
if len(jj_ags) == 0 and len(pp_ags) != 0:
raise AttributeError(f"{func_name} in Jittor has no Attribute {pp_ags[0]}")
if delete is not None:
for d in delete:
if d in pp_ags:
jj_ags.append(d)
if d in pp_kws.keys():
jj_kws[d] = None
if len(pp_ags) > len(ags) + len(kws):
raise RuntimeError(f'There are needed {len(pp_ags) + len(list(pp_kws.keys()))} args in Pytorch {func_name} function, but you only provide {len(ags) + len(kws)}')
ags_ = []
@ -395,6 +474,12 @@ def convert_(prefix, func_name, ags, kws):
j_kws_values[k] = extras[k]
else:
raise AttributeError(f"there is not attribute named {k} in Jittor {func_name}, you should delete it in {func_name} extras.")
if delete is not None:
for d in delete:
if d in j_ags_values:
j_ags_values.remove(d)
if d in j_kws_values.keys():
j_kws_values.pop(d)
j_ags_ = [j_ags_values[str(i)] for i in range(len(list(j_ags_values.keys())))]
j_kws_ = [key + "=" + j_kws_values[key] for key in j_kws_values.keys()]
j_func = f"{j_module}.{j_name}({', '.join(j_ags_+j_kws_)})"
@ -412,10 +497,10 @@ def dfs(a):
if 'torch' in astunparse.unparse(a) and 'init' in astunparse.unparse(a):
import_flag.append('init')
return ast.parse('from jittor import init').body[0]
if 'torch' in astunparse.unparse(a) and 'nn' in astunparse.unparse(a):
if 'torch' in astunparse.unparse(a) and a.names[0].asname == 'nn':
import_flag.append('nn')
return ast.parse('from jittor import nn').body[0]
if a.names[0].name == 'torch':
if 'torch' in a.names[0].name:
return 'delete'
elif isinstance(a, ast.ImportFrom):
if 'torch' in a.module:
@ -460,7 +545,6 @@ def dfs(a):
ret = dfs(a_)
if ret == 'delete':
delete_flag.append(True)
del a.__dict__[k][i]
continue
if ret is not None:
a.__dict__[k][i] = ret
@ -470,4 +554,4 @@ def dfs(a):
else:
ret = dfs(a.__dict__[k])
if ret is not None:
a.__dict__[k] = ret
a.__dict__[k] = ret