mirror of https://github.com/Jittor/Jittor
update
This commit is contained in:
parent
4ff60ee04c
commit
d74dc1c148
|
@ -602,6 +602,21 @@ class ReplicationPad2d(Module):
|
|||
y_idx[j,i] = b
|
||||
return x.reindex([n,c,oh,ow], ["i0","i1","@e1(i2,i3)","@e0(i2,i3)"], extras=[jt.array(x_idx - self.pl), jt.array(y_idx - self.pt)])
|
||||
|
||||
class PixelShuffle(Module):
|
||||
def __init__(self, upscale_factor):
|
||||
self.upscale_factor = upscale_factor
|
||||
|
||||
def execute(self, x):
|
||||
n,c,h,w = x.shape
|
||||
r = self.upscale_factor
|
||||
assert c%(r**2)==0, f"input channel needs to be divided by upscale_factor's square in PixelShuffle"
|
||||
return x.reindex([n,int(c/r**2),h*r,w*r], [
|
||||
"i0",
|
||||
f"i1*{r**2}+i2%{r}*{r}+i3%{r}",
|
||||
f"i2/{r}",
|
||||
f"i3/{r}"
|
||||
])
|
||||
|
||||
class Tanh(Module):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
|
|
|
@ -19,6 +19,7 @@ try:
|
|||
except:
|
||||
torch = None
|
||||
tnn = None
|
||||
skip_this_test = True
|
||||
|
||||
def check_equal(arr, j_layer, p_layer, threshold=1e-5):
|
||||
jittor_arr = jt.array(arr)
|
||||
|
@ -27,6 +28,7 @@ def check_equal(arr, j_layer, p_layer, threshold=1e-5):
|
|||
pytorch_result = p_layer(pytorch_arr)
|
||||
assert np.allclose(pytorch_result.detach().numpy(), jittor_result.numpy(), threshold)
|
||||
|
||||
@unittest.skipIf(skip_this_test, "No Torch found")
|
||||
class TestBatchNorm(unittest.TestCase):
|
||||
def test_batchnorm(self):
|
||||
# ***************************************************************
|
||||
|
|
|
@ -19,6 +19,7 @@ try:
|
|||
except:
|
||||
torch = None
|
||||
tnn = None
|
||||
skip_this_test = True
|
||||
|
||||
def check_equal(arr, j_layer, p_layer):
|
||||
jittor_arr = jt.array(arr)
|
||||
|
@ -27,6 +28,7 @@ def check_equal(arr, j_layer, p_layer):
|
|||
pytorch_result = p_layer(pytorch_arr)
|
||||
assert np.allclose(pytorch_result.detach().numpy(), jittor_result.numpy())
|
||||
|
||||
@unittest.skipIf(skip_this_test, "No Torch found")
|
||||
class TestPad(unittest.TestCase):
|
||||
def test_pad(self):
|
||||
# ***************************************************************
|
||||
|
|
|
@ -19,6 +19,7 @@ try:
|
|||
except:
|
||||
torch = None
|
||||
tnn = None
|
||||
skip_this_test = True
|
||||
|
||||
def check_equal(arr, j_layer, p_layer):
|
||||
jittor_arr = jt.array(arr)
|
||||
|
@ -27,6 +28,7 @@ def check_equal(arr, j_layer, p_layer):
|
|||
pytorch_result = p_layer(pytorch_arr)
|
||||
assert np.allclose(pytorch_result.detach().numpy(), jittor_result.numpy())
|
||||
|
||||
@unittest.skipIf(skip_this_test, "No Torch found")
|
||||
class TestRelu(unittest.TestCase):
|
||||
def test_relu(self):
|
||||
# ***************************************************************
|
||||
|
|
|
@ -20,6 +20,7 @@ try:
|
|||
except:
|
||||
torch = None
|
||||
tnn = None
|
||||
skip_this_test = True
|
||||
|
||||
mid = 0
|
||||
if os.uname()[1] == "jittor-ce":
|
||||
|
@ -84,6 +85,13 @@ def test_case(box_num, out_size, time_limit):
|
|||
assert fused_op_num == 1, fused_op_num
|
||||
assert t <= time_limit, t
|
||||
|
||||
def check_equal(arr, j_layer, p_layer):
|
||||
jittor_arr = jt.array(arr)
|
||||
pytorch_arr = torch.Tensor(arr)
|
||||
jittor_result = j_layer(jittor_arr)
|
||||
pytorch_result = p_layer(pytorch_arr)
|
||||
assert np.allclose(pytorch_result.detach().numpy(), jittor_result.numpy())
|
||||
|
||||
class TestResizeAndCrop(unittest.TestCase):
|
||||
def test(self):
|
||||
test_case(100, [224, 224], 0.45)
|
||||
|
@ -92,22 +100,15 @@ class TestResizeAndCrop(unittest.TestCase):
|
|||
test_case(20, [1024, 666], [0.8,1.0][mid])
|
||||
|
||||
def test_upsample(self):
|
||||
# ***************************************************************
|
||||
# Define jittor & pytorch array
|
||||
# ***************************************************************
|
||||
arr = np.random.randn(16,10,224,224)
|
||||
jittor_arr = jt.array(arr)
|
||||
pytorch_arr = torch.Tensor(arr)
|
||||
# ***************************************************************
|
||||
# Test Upsample Layer
|
||||
# ***************************************************************
|
||||
pytorch_result = tnn.Upsample(scale_factor=2)(pytorch_arr)
|
||||
jittor_result = jnn.Upsample(scale_factor=2)(jittor_arr)
|
||||
assert np.allclose(pytorch_result.numpy(), jittor_result.numpy())
|
||||
|
||||
pytorch_result = tnn.Upsample(scale_factor=0.2)(pytorch_arr)
|
||||
jittor_result = jnn.Upsample(scale_factor=0.2)(jittor_arr)
|
||||
assert np.allclose(pytorch_result.numpy(), jittor_result.numpy())
|
||||
check_equal(arr, jnn.Upsample(scale_factor=2), tnn.Upsample(scale_factor=2))
|
||||
check_equal(arr, jnn.Upsample(scale_factor=0.2), tnn.Upsample(scale_factor=0.2))
|
||||
|
||||
def test_pixelshuffle(self):
|
||||
arr = np.random.randn(16,16,224,224)
|
||||
check_equal(arr, jnn.PixelShuffle(upscale_factor=2), tnn.PixelShuffle(upscale_factor=2))
|
||||
arr = np.random.randn(1,16*16,224,224)
|
||||
check_equal(arr, jnn.PixelShuffle(upscale_factor=16), tnn.PixelShuffle(upscale_factor=16))
|
||||
|
||||
if __name__ == "__main__":
|
||||
unittest.main()
|
||||
|
|
|
@ -141,18 +141,6 @@ pjmap = {
|
|||
'extras': {'affine': 'None'},
|
||||
'delete': ['track_running_stats'],
|
||||
},
|
||||
'InstanceNorm2d': {
|
||||
'pytorch': {
|
||||
'args': "num_features, eps=1e-05, momentum=0.1, affine=False, track_running_stats=False"
|
||||
},
|
||||
'jittor': {
|
||||
'module': 'nn',
|
||||
'name': 'InstanceNorm2d',
|
||||
'args': 'num_features, eps=1e-05, momentum=0.1, affine=None, is_train=True, sync=True'
|
||||
},
|
||||
'links': {},
|
||||
'extras': {'affine': 'None'},
|
||||
},
|
||||
'Dropout2d': {
|
||||
'pytorch': {
|
||||
'args': 'p=0.5, inplace=False',
|
||||
|
@ -214,6 +202,18 @@ pjmap = {
|
|||
'links': {'tensor': 'var'},
|
||||
'extras': {},
|
||||
},
|
||||
'uniform_': {
|
||||
'pytorch': {
|
||||
'args': "tensor, a=0.0, b=1.0",
|
||||
},
|
||||
'jittor': {
|
||||
'module': 'init',
|
||||
'name': 'uniform_',
|
||||
'args': 'var, low, high'
|
||||
},
|
||||
'links': {'tensor': 'var', 'a': 'low', 'b': 'high'},
|
||||
'extras': {},
|
||||
},
|
||||
'cat': {
|
||||
'pytorch': {
|
||||
'args': "tensors, dim=0, out=None",
|
||||
|
@ -278,7 +278,6 @@ pjmap = {
|
|||
'links': {},
|
||||
'extras': {},
|
||||
},
|
||||
# 好像不需要如果一毛一样的话
|
||||
'view': {
|
||||
'pytorch': {
|
||||
'prefix': [],
|
||||
|
|
Loading…
Reference in New Issue