mirror of https://github.com/Jittor/Jittor
update
commit 4ff60ee04c (parent 264148a017)
@@ -285,39 +285,6 @@ class BatchNorm(Module):
         b = self.bias.broadcast(x, [0,2,3])
         return norm_x * w + b
 
-class InstanceNorm2d(Module):
-    def __init__(self, num_features, eps=1e-05, momentum=0.1, affine=None, is_train=True, sync=True):
-        assert affine == None
-        self.sync = sync
-        self.num_features = num_features
-        self.is_train = is_train
-        self.eps = eps
-        self.momentum = momentum
-        self.weight = init.constant((num_features,), "float32", 1.0)
-        self.bias = init.constant((num_features,), "float32", 0.0)
-        self.running_mean = init.constant((num_features,), "float32", 0.0).stop_grad()
-        self.running_var = init.constant((num_features,), "float32", 1.0).stop_grad()
-
-    def execute(self, x):
-        if self.is_train:
-            xmean = jt.mean(x, dims=[2,3], keepdims=1)
-            x2mean = jt.mean(x*x, dims=[2,3], keepdims=1)
-            if self.sync and jt.mpi:
-                xmean = xmean.mpi_all_reduce("mean")
-                x2mean = x2mean.mpi_all_reduce("mean")
-
-            xvar = x2mean-xmean*xmean
-            norm_x = (x-xmean)/jt.sqrt(xvar+self.eps)
-            self.running_mean += (xmean.sum([0,2,3])-self.running_mean)*self.momentum
-            self.running_var += (xvar.sum([0,2,3])-self.running_var)*self.momentum
-        else:
-            running_mean = self.running_mean.broadcast(x, [0,2,3])
-            running_var = self.running_var.broadcast(x, [0,2,3])
-            norm_x = (x-running_mean)/jt.sqrt(running_var+self.eps)
-        w = self.weight.broadcast(x, [0,2,3])
-        b = self.bias.broadcast(x, [0,2,3])
-        return norm_x * w + b
-
 class BatchNorm1d(Module):
     def __init__(self, num_features, eps=1e-5, momentum=0.1, affine=None, is_train=True, sync=True):
         assert affine == None
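
The removed InstanceNorm2d computes its statistics per sample and per channel, reducing over the spatial dims [2,3], and relies on the shortcut Var[x] = E[x^2] - E[x]^2. A minimal numpy sketch of the same computation (shapes and names here are illustrative, not from the commit):

    import numpy as np

    # per-sample, per-channel statistics over the spatial dims (2, 3),
    # mirroring xmean / x2mean / xvar in execute() above
    x = np.random.randn(2, 3, 4, 4)
    xmean = x.mean(axis=(2, 3), keepdims=True)
    x2mean = (x * x).mean(axis=(2, 3), keepdims=True)
    xvar = x2mean - xmean * xmean              # E[x^2] - E[x]^2
    assert np.allclose(xvar, x.var(axis=(2, 3), keepdims=True))
    norm_x = (x - xmean) / np.sqrt(xvar + 1e-5)
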
@@ -377,7 +344,9 @@ class Conv(Module):
 
         self.weight = init.relu_invariant_gauss([out_channels, in_channels//groups, Kh, Kw], dtype="float", mode="fan_out")
         if bias:
-            self.bias = init.uniform([out_channels], dtype="float", low=-1, high=1)
+            shape = self.weight.shape
+            bound = 1 / math.sqrt(shape[1] * shape[2] * shape[3])
+            self.bias = init.uniform([out_channels], dtype="float", low=-bound, high=bound)
         else:
             self.bias = None
 
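
The new bias range is 1/sqrt(fan_in) with fan_in = weight.shape[1] * shape[2] * shape[3], i.e. (in_channels // groups) * Kh * Kw; this appears to mirror the Kaiming-uniform bias default of torch.nn.Conv2d, replacing the old fixed [-1, 1]. A quick sketch with assumed sizes:

    import math

    # hypothetical channel/kernel sizes, purely for illustration
    in_channels, groups, Kh, Kw = 64, 1, 3, 3
    fan_in = (in_channels // groups) * Kh * Kw
    bound = 1 / math.sqrt(fan_in)
    print(bound)   # ~0.0417, far tighter than the old [-1, 1]
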
@@ -433,7 +402,7 @@ class Conv(Module):
         if self.bias is not None:
             b = self.bias.broadcast(y.shape, [0,2,3])
             y = y + b
-            return y
+        return y
 
 
 class ConvTranspose(Module):
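
The change above is an indentation fix: with return y nested under the if, a Conv constructed with bias=False fell off the end of the method and implicitly returned None. A minimal reproduction of that failure mode (plain Python, hypothetical names):

    def buggy(y, bias=None):
        if bias is not None:
            y = y + bias
            return y           # only reached when a bias exists

    print(buggy(1.0, 2.0))     # 3.0
    print(buggy(1.0))          # None -- the silent bug the dedent fixes
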
@@ -20,39 +20,46 @@ except:
     torch = None
     tnn = None
 
-def check_equal(a, b):
-    eps = 1e-1 # icc error almost reaches 1e-1
-    relative_error = (abs(a - b) / abs(b + 1)).mean()
-    print(f"relative_error: {relative_error}")
-    return relative_error < eps
+def check_equal(arr, j_layer, p_layer, threshold=1e-5):
+    jittor_arr = jt.array(arr)
+    pytorch_arr = torch.Tensor(arr)
+    jittor_result = j_layer(jittor_arr)
+    pytorch_result = p_layer(pytorch_arr)
+    assert np.allclose(pytorch_result.detach().numpy(), jittor_result.numpy(), threshold)
 
 class TestBatchNorm(unittest.TestCase):
     def test_batchnorm(self):
         # ***************************************************************
-        # Define jittor & pytorch array
+        # Test BatchNorm Layer
         # ***************************************************************
         arr = np.random.randn(16,10,224,224)
-        jittor_arr = jt.array(arr)
-        pytorch_arr = torch.Tensor(arr)
-        # ***************************************************************
-        # Test InstanceNorm2d Layer
-        # ***************************************************************
-        pytorch_result = tnn.InstanceNorm2d(10)(pytorch_arr)
-        jittor_result = jnn.InstanceNorm2d(10)(jittor_arr)
-        assert check_equal(pytorch_result.numpy(), jittor_result.numpy()), f"{pytorch_result.mean()} || {jittor_result.mean()}"
+        check_equal(arr, jnn.BatchNorm(10, is_train=True), tnn.BatchNorm2d(10))
+
+        class Model(tnn.Module):
+            def __init__(self):
+                super(Model, self).__init__()
+                self.layer = tnn.BatchNorm2d(10)
+            def forward(self, x):
+                return self.layer(x)
+        model = Model()
+        model.eval()
+        check_equal(arr, jnn.BatchNorm(10, is_train=False), model)
 
         # ***************************************************************
-        # Define jittor & pytorch array
+        # Test BatchNorm1d Layer
         # ***************************************************************
         arr = np.random.randn(16,1000)
-        jittor_arr = jt.array(arr)
-        pytorch_arr = torch.Tensor(arr)
-        # ***************************************************************
-        # Test InstanceNorm2d Layer
-        # ***************************************************************
-        pytorch_result = tnn.BatchNorm1d(1000)(pytorch_arr)
-        jittor_result = jnn.BatchNorm1d(1000)(jittor_arr)
-        assert check_equal(pytorch_result.detach().numpy(), jittor_result.numpy()), f"{pytorch_result.mean()} || {jittor_result.mean()}"
+        check_equal(arr, jnn.BatchNorm1d(1000), tnn.BatchNorm1d(1000), 1e-3)
+
+        class Model(tnn.Module):
+            def __init__(self):
+                super(Model, self).__init__()
+                self.layer = tnn.BatchNorm1d(1000)
+            def forward(self, x):
+                return self.layer(x)
+        model = Model()
+        model.eval()
+        check_equal(arr, jnn.BatchNorm1d(1000, is_train=False), model, 1e-3)
 
 if __name__ == "__main__":
     unittest.main()
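
One subtlety in the new helper: threshold is passed to np.allclose as the third positional argument, which is rtol, the relative tolerance (atol keeps its 1e-8 default). A small demonstration:

    import numpy as np

    a = np.array([1000.0])
    # allclose passes when |a - b| <= atol + rtol * |b|
    print(np.allclose(a, a + 0.5, 1e-3))   # True:  0.5 <= ~1.0
    print(np.allclose(a, a + 2.0, 1e-3))   # False: 2.0 exceeds the bound
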
@@ -20,56 +20,42 @@ except:
     torch = None
     tnn = None
 
-def check_equal(a, b):
-    eps = 1e-1 # icc error almost reaches 1e-1
-    relative_error = (abs(a - b) / abs(b + 1)).mean()
-    print(f"relative_error: {relative_error}")
-    return relative_error < eps
+def check_equal(arr, j_layer, p_layer):
+    jittor_arr = jt.array(arr)
+    pytorch_arr = torch.Tensor(arr)
+    jittor_result = j_layer(jittor_arr)
+    pytorch_result = p_layer(pytorch_arr)
+    assert np.allclose(pytorch_result.detach().numpy(), jittor_result.numpy())
 
 class TestPad(unittest.TestCase):
     def test_pad(self):
         # ***************************************************************
-        # Define jittor & pytorch array
-        # ***************************************************************
-        arr = np.random.randn(16,3,224,224)
-        jittor_arr = jt.array(arr)
-        pytorch_arr = torch.Tensor(arr)
-        # ***************************************************************
         # Test ReplicationPad2d Layer
         # ***************************************************************
-        pytorch_result = tnn.ReplicationPad2d(10)(pytorch_arr)
-        jittor_result = jnn.ReplicationPad2d(10)(jittor_arr)
-        assert check_equal(pytorch_result.numpy(), jittor_result.numpy()), f"{pytorch_result.mean()} || {jittor_result.mean()}"
-        pytorch_result = tnn.ReplicationPad2d((1,23,4,5))(pytorch_arr)
-        jittor_result = jnn.ReplicationPad2d((1,23,4,5))(jittor_arr)
-        assert check_equal(pytorch_result.numpy(), jittor_result.numpy()), f"{pytorch_result.mean()} || {jittor_result.mean()}"
+        arr = np.random.randn(16,3,224,224)
+        check_equal(arr, jnn.ReplicationPad2d(10), tnn.ReplicationPad2d(10))
+        check_equal(arr, jnn.ReplicationPad2d((1,23,4,5)), tnn.ReplicationPad2d((1,23,4,5)))
+
         # ***************************************************************
         # Test ConstantPad2d Layer
         # ***************************************************************
-        pytorch_result = tnn.ConstantPad2d(10,-2)(pytorch_arr)
-        jittor_result = jnn.ConstantPad2d(10,-2)(jittor_arr)
-        assert check_equal(pytorch_result.numpy(), jittor_result.numpy()), f"{pytorch_result.mean()} || {jittor_result.mean()}"
-        pytorch_result = tnn.ConstantPad2d((2,3,34,1),10.2)(pytorch_arr)
-        jittor_result = jnn.ConstantPad2d((2,3,34,1),10.2)(jittor_arr)
-        assert check_equal(pytorch_result.numpy(), jittor_result.numpy()), f"{pytorch_result.mean()} || {jittor_result.mean()}"
+        arr = np.random.randn(16,3,224,224)
+        check_equal(arr, jnn.ConstantPad2d(10,-2), tnn.ConstantPad2d(10,-2))
+        check_equal(arr, jnn.ConstantPad2d((2,3,34,1),10.2), tnn.ConstantPad2d((2,3,34,1),10.2))
+
         # ***************************************************************
         # Test ZeroPad2d Layer
         # ***************************************************************
-        pytorch_result = tnn.ZeroPad2d(1)(pytorch_arr)
-        jittor_result = jnn.ZeroPad2d(1)(jittor_arr)
-        assert check_equal(pytorch_result.numpy(), jittor_result.numpy()), f"{pytorch_result.mean()} || {jittor_result.mean()}"
-        pytorch_result = tnn.ZeroPad2d((2,3,34,1))(pytorch_arr)
-        jittor_result = jnn.ZeroPad2d((2,3,34,1))(jittor_arr)
-        assert check_equal(pytorch_result.numpy(), jittor_result.numpy()), f"{pytorch_result.mean()} || {jittor_result.mean()}"
+        arr = np.random.randn(16,3,224,224)
+        check_equal(arr, jnn.ZeroPad2d(1), tnn.ZeroPad2d(1))
+        check_equal(arr, jnn.ZeroPad2d((2,3,34,1)), tnn.ZeroPad2d((2,3,34,1)))
+
         # ***************************************************************
         # Test ReflectionPad2d Layer
         # ***************************************************************
-        pytorch_result = tnn.ReflectionPad2d(20)(pytorch_arr)
-        jittor_result = jnn.ReflectionPad2d(20)(jittor_arr)
-        assert check_equal(pytorch_result.numpy(), jittor_result.numpy()), f"{pytorch_result.mean()} || {jittor_result.mean()}"
-        pytorch_result = tnn.ReflectionPad2d((2,3,34,1))(pytorch_arr)
-        jittor_result = jnn.ReflectionPad2d((2,3,34,1))(jittor_arr)
-        assert check_equal(pytorch_result.numpy(), jittor_result.numpy()), f"{pytorch_result.mean()} || {jittor_result.mean()}"
+        arr = np.random.randn(16,3,224,224)
+        check_equal(arr, jnn.ReflectionPad2d(20), tnn.ReflectionPad2d(20))
+        check_equal(arr, jnn.ReflectionPad2d((2,3,34,1)), tnn.ReflectionPad2d((2,3,34,1)))
 
 if __name__ == "__main__":
     unittest.main()
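
The 4-tuples in these tests follow the (left, right, top, bottom) convention of the PyTorch padding layers. A numpy stand-in for ZeroPad2d((1,23,4,5)) on an NCHW array makes the resulting shape concrete:

    import numpy as np

    x = np.zeros((16, 3, 224, 224))
    # np.pad takes (before, after) per axis: H gets (top, bottom), W gets (left, right)
    y = np.pad(x, ((0, 0), (0, 0), (4, 5), (1, 23)))
    print(y.shape)   # (16, 3, 233, 248)
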
@@ -20,52 +20,43 @@ except:
     torch = None
     tnn = None
 
-def check_equal(a, b):
-    eps = 1e-1 # icc error almost reaches 1e-1
-    relative_error = (abs(a - b) / abs(b + 1)).mean()
-    print(f"relative_error: {relative_error}")
-    return relative_error < eps
+def check_equal(arr, j_layer, p_layer):
+    jittor_arr = jt.array(arr)
+    pytorch_arr = torch.Tensor(arr)
+    jittor_result = j_layer(jittor_arr)
+    pytorch_result = p_layer(pytorch_arr)
+    assert np.allclose(pytorch_result.detach().numpy(), jittor_result.numpy())
 
 class TestRelu(unittest.TestCase):
     def test_relu(self):
         # ***************************************************************
-        # Define jittor & pytorch array
+        # Test ReLU Layer
         # ***************************************************************
         arr = np.random.randn(16,10,224,224)
-        jittor_arr = jt.array(arr)
-        pytorch_arr = torch.Tensor(arr)
+        check_equal(arr, jnn.ReLU(), tnn.ReLU())
+
         # ***************************************************************
         # Test PReLU Layer
         # ***************************************************************
-        pytorch_result = tnn.PReLU(10, 2)(pytorch_arr)
-        jittor_result = jnn.PReLU(10, 2)(jittor_arr)
-        assert check_equal(pytorch_result.detach().numpy(), jittor_result.numpy()), f"{pytorch_result.mean()} || {jittor_result.mean()}"
-        pytorch_result = tnn.PReLU(10, -0.2)(pytorch_arr)
-        jittor_result = jnn.PReLU(10, -0.2)(jittor_arr)
-        assert check_equal(pytorch_result.detach().numpy(), jittor_result.numpy()), f"{pytorch_result.mean()} || {jittor_result.mean()}"
-        pytorch_result = tnn.PReLU(10, 99.9)(pytorch_arr)
-        jittor_result = jnn.PReLU(10, 99.9)(jittor_arr)
-        assert check_equal(pytorch_result.detach().numpy(), jittor_result.numpy()), f"{pytorch_result.mean()} || {jittor_result.mean()}"
+        arr = np.random.randn(16,10,224,224)
+        check_equal(arr, jnn.PReLU(), tnn.PReLU())
+        check_equal(arr, jnn.PReLU(10, 99.9), tnn.PReLU(10, 99.9))
+        check_equal(arr, jnn.PReLU(10, 2), tnn.PReLU(10, 2))
+        check_equal(arr, jnn.PReLU(10, -0.2), tnn.PReLU(10, -0.2))
 
         # ***************************************************************
         # Test ReLU6 Layer
         # ***************************************************************
-        pytorch_result = tnn.ReLU6()(pytorch_arr)
-        jittor_result = jnn.ReLU6()(jittor_arr)
-        assert check_equal(pytorch_result.detach().numpy(), jittor_result.numpy()), f"{pytorch_result.mean()} || {jittor_result.mean()}"
+        arr = np.random.randn(16,10,224,224)
+        check_equal(arr, jnn.ReLU6(), tnn.ReLU6())
 
         # ***************************************************************
         # Test LeakyReLU Layer
         # ***************************************************************
-        pytorch_result = tnn.LeakyReLU(2)(pytorch_arr)
-        jittor_result = jnn.LeakyReLU(2)(jittor_arr)
-        assert check_equal(pytorch_result.detach().numpy(), jittor_result.numpy()), f"{pytorch_result.mean()} || {jittor_result.mean()}"
-        pytorch_result = tnn.LeakyReLU()(pytorch_arr)
-        jittor_result = jnn.LeakyReLU()(jittor_arr)
-        assert check_equal(pytorch_result.detach().numpy(), jittor_result.numpy()), f"{pytorch_result.mean()} || {jittor_result.mean()}"
-        pytorch_result = tnn.LeakyReLU(99.9)(pytorch_arr)
-        jittor_result = jnn.LeakyReLU(99.9)(jittor_arr)
-        assert check_equal(pytorch_result.detach().numpy(), jittor_result.numpy()), f"{pytorch_result.mean()} || {jittor_result.mean()}"
 
+        arr = np.random.randn(16,10,224,224)
+        check_equal(arr, jnn.LeakyReLU(), tnn.LeakyReLU())
+        check_equal(arr, jnn.LeakyReLU(2), tnn.LeakyReLU(2))
+        check_equal(arr, jnn.LeakyReLU(99.9), tnn.LeakyReLU(99.9))
 
 if __name__ == "__main__":
     unittest.main()
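
The slope grid here covers negative, sub-unit, and very large values; all share the same elementwise rule, x for x >= 0 and slope * x otherwise. PReLU(10, 2) simply makes the slope learnable per channel: ten parameters, each initialized to 2. A plain numpy sketch of the rule:

    import numpy as np

    def leaky(x, slope):
        # x where non-negative, slope * x elsewhere
        return np.where(x >= 0, x, slope * x)

    x = np.array([-1.5, 0.0, 2.0])
    print(leaky(x, 0.01))   # default LeakyReLU slope
    print(leaky(x, 99.9))   # extreme slope from the test grid
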
@@ -84,12 +84,6 @@ def test_case(box_num, out_size, time_limit):
     assert fused_op_num == 1, fused_op_num
     assert t <= time_limit, t
 
-def check_equal(a, b):
-    eps = 1e-1 # icc error almost reaches 1e-1
-    relative_error = (abs(a - b) / abs(b + 1)).mean()
-    print(f"relative_error: {relative_error}")
-    return relative_error < eps
-
 class TestResizeAndCrop(unittest.TestCase):
     def test(self):
         test_case(100, [224, 224], 0.45)
@@ -109,11 +103,11 @@ class TestResizeAndCrop(unittest.TestCase):
         # ***************************************************************
         pytorch_result = tnn.Upsample(scale_factor=2)(pytorch_arr)
         jittor_result = jnn.Upsample(scale_factor=2)(jittor_arr)
-        assert check_equal(pytorch_result.numpy(), jittor_result.numpy()), f"{pytorch_result.mean()} || {jittor_result.mean()}"
+        assert np.allclose(pytorch_result.numpy(), jittor_result.numpy())
 
         pytorch_result = tnn.Upsample(scale_factor=0.2)(pytorch_arr)
         jittor_result = jnn.Upsample(scale_factor=0.2)(jittor_arr)
-        assert check_equal(pytorch_result.numpy(), jittor_result.numpy()), f"{pytorch_result.mean()} || {jittor_result.mean()}"
+        assert np.allclose(pytorch_result.numpy(), jittor_result.numpy())
 
 if __name__ == "__main__":
     unittest.main()
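
Dropping the loose check_equal in favour of plain np.allclose is reasonable here: nearest-neighbour resampling only copies existing values, so the two frameworks should agree to within float rounding. A numpy stand-in for what scale_factor=2 does (not the library call itself):

    import numpy as np

    x = np.arange(4.0).reshape(1, 1, 2, 2)
    up = x.repeat(2, axis=2).repeat(2, axis=3)   # nearest, scale_factor=2
    print(up[0, 0])
    # [[0. 0. 1. 1.]
    #  [0. 0. 1. 1.]
    #  [2. 2. 3. 3.]
    #  [2. 2. 3. 3.]]
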
@@ -138,7 +138,7 @@ class RandomCrop:
     def __call__(self, img:Image.Image):
         width, height = img.size
         assert self.size[0] <= height and self.size[1] <= width, f"crop size exceeds the input image in RandomCrop"
-        top = np.random.randint(0,height-self.size[0])
-        left = np.random.randint(0,width-self.size[1])
+        top = np.random.randint(0,height-self.size[0]+1)
+        left = np.random.randint(0,width-self.size[1]+1)
         return crop(img, top, left, self.size[0], self.size[1])
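
The +1 matters because np.random.randint excludes its upper bound: for a crop exactly the size of the image, height - self.size[0] is 0, and randint(0, 0) raises ValueError instead of returning top = 0. The boundary case:

    import numpy as np

    height, crop_h = 224, 224                         # crop as large as the image
    print(np.random.randint(0, height - crop_h + 1))  # always 0, as intended
    # without the +1: np.random.randint(0, 0) -> ValueError: low >= high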