This commit is contained in:
zwy 2020-04-29 21:02:51 +08:00
parent 4ff60ee04c
commit d74dc1c148
6 changed files with 49 additions and 28 deletions

View File

@ -602,6 +602,21 @@ class ReplicationPad2d(Module):
y_idx[j,i] = b
return x.reindex([n,c,oh,ow], ["i0","i1","@e1(i2,i3)","@e0(i2,i3)"], extras=[jt.array(x_idx - self.pl), jt.array(y_idx - self.pt)])
class PixelShuffle(Module):
def __init__(self, upscale_factor):
self.upscale_factor = upscale_factor
def execute(self, x):
n,c,h,w = x.shape
r = self.upscale_factor
assert c%(r**2)==0, f"input channel needs to be divided by upscale_factor's square in PixelShuffle"
return x.reindex([n,int(c/r**2),h*r,w*r], [
"i0",
f"i1*{r**2}+i2%{r}*{r}+i3%{r}",
f"i2/{r}",
f"i3/{r}"
])
class Tanh(Module):
def __init__(self):
super().__init__()

View File

@ -19,6 +19,7 @@ try:
except:
torch = None
tnn = None
skip_this_test = True
def check_equal(arr, j_layer, p_layer, threshold=1e-5):
jittor_arr = jt.array(arr)
@ -27,6 +28,7 @@ def check_equal(arr, j_layer, p_layer, threshold=1e-5):
pytorch_result = p_layer(pytorch_arr)
assert np.allclose(pytorch_result.detach().numpy(), jittor_result.numpy(), threshold)
@unittest.skipIf(skip_this_test, "No Torch found")
class TestBatchNorm(unittest.TestCase):
def test_batchnorm(self):
# ***************************************************************

View File

@ -19,6 +19,7 @@ try:
except:
torch = None
tnn = None
skip_this_test = True
def check_equal(arr, j_layer, p_layer):
jittor_arr = jt.array(arr)
@ -27,6 +28,7 @@ def check_equal(arr, j_layer, p_layer):
pytorch_result = p_layer(pytorch_arr)
assert np.allclose(pytorch_result.detach().numpy(), jittor_result.numpy())
@unittest.skipIf(skip_this_test, "No Torch found")
class TestPad(unittest.TestCase):
def test_pad(self):
# ***************************************************************

View File

@ -19,6 +19,7 @@ try:
except:
torch = None
tnn = None
skip_this_test = True
def check_equal(arr, j_layer, p_layer):
jittor_arr = jt.array(arr)
@ -27,6 +28,7 @@ def check_equal(arr, j_layer, p_layer):
pytorch_result = p_layer(pytorch_arr)
assert np.allclose(pytorch_result.detach().numpy(), jittor_result.numpy())
@unittest.skipIf(skip_this_test, "No Torch found")
class TestRelu(unittest.TestCase):
def test_relu(self):
# ***************************************************************

View File

@ -20,6 +20,7 @@ try:
except:
torch = None
tnn = None
skip_this_test = True
mid = 0
if os.uname()[1] == "jittor-ce":
@ -84,6 +85,13 @@ def test_case(box_num, out_size, time_limit):
assert fused_op_num == 1, fused_op_num
assert t <= time_limit, t
def check_equal(arr, j_layer, p_layer):
jittor_arr = jt.array(arr)
pytorch_arr = torch.Tensor(arr)
jittor_result = j_layer(jittor_arr)
pytorch_result = p_layer(pytorch_arr)
assert np.allclose(pytorch_result.detach().numpy(), jittor_result.numpy())
class TestResizeAndCrop(unittest.TestCase):
def test(self):
test_case(100, [224, 224], 0.45)
@ -92,22 +100,15 @@ class TestResizeAndCrop(unittest.TestCase):
test_case(20, [1024, 666], [0.8,1.0][mid])
def test_upsample(self):
# ***************************************************************
# Define jittor & pytorch array
# ***************************************************************
arr = np.random.randn(16,10,224,224)
jittor_arr = jt.array(arr)
pytorch_arr = torch.Tensor(arr)
# ***************************************************************
# Test Upsample Layer
# ***************************************************************
pytorch_result = tnn.Upsample(scale_factor=2)(pytorch_arr)
jittor_result = jnn.Upsample(scale_factor=2)(jittor_arr)
assert np.allclose(pytorch_result.numpy(), jittor_result.numpy())
pytorch_result = tnn.Upsample(scale_factor=0.2)(pytorch_arr)
jittor_result = jnn.Upsample(scale_factor=0.2)(jittor_arr)
assert np.allclose(pytorch_result.numpy(), jittor_result.numpy())
check_equal(arr, jnn.Upsample(scale_factor=2), tnn.Upsample(scale_factor=2))
check_equal(arr, jnn.Upsample(scale_factor=0.2), tnn.Upsample(scale_factor=0.2))
def test_pixelshuffle(self):
arr = np.random.randn(16,16,224,224)
check_equal(arr, jnn.PixelShuffle(upscale_factor=2), tnn.PixelShuffle(upscale_factor=2))
arr = np.random.randn(1,16*16,224,224)
check_equal(arr, jnn.PixelShuffle(upscale_factor=16), tnn.PixelShuffle(upscale_factor=16))
if __name__ == "__main__":
unittest.main()

View File

@ -141,18 +141,6 @@ pjmap = {
'extras': {'affine': 'None'},
'delete': ['track_running_stats'],
},
'InstanceNorm2d': {
'pytorch': {
'args': "num_features, eps=1e-05, momentum=0.1, affine=False, track_running_stats=False"
},
'jittor': {
'module': 'nn',
'name': 'InstanceNorm2d',
'args': 'num_features, eps=1e-05, momentum=0.1, affine=None, is_train=True, sync=True'
},
'links': {},
'extras': {'affine': 'None'},
},
'Dropout2d': {
'pytorch': {
'args': 'p=0.5, inplace=False',
@ -214,6 +202,18 @@ pjmap = {
'links': {'tensor': 'var'},
'extras': {},
},
'uniform_': {
'pytorch': {
'args': "tensor, a=0.0, b=1.0",
},
'jittor': {
'module': 'init',
'name': 'uniform_',
'args': 'var, low, high'
},
'links': {'tensor': 'var', 'a': 'low', 'b': 'high'},
'extras': {},
},
'cat': {
'pytorch': {
'args': "tensors, dim=0, out=None",
@ -278,7 +278,6 @@ pjmap = {
'links': {},
'extras': {},
},
# 好像不需要如果一毛一样的话
'view': {
'pytorch': {
'prefix': [],