mirror of https://github.com/Jittor/Jittor

fix init bug

commit d98958e627, parent 78cc826756
@@ -15,10 +15,14 @@ __all__ = ['ResNet', 'Resnet18', 'Resnet34', 'Resnet50', 'Resnet101', 'Resnet152
     'resnet18', 'resnet34', 'resnet50', 'resnet101', 'resnet152', 'resnext50_32x4d', 'resnext101_32x8d', 'wide_resnet50_2', 'wide_resnet101_2']
 
 def conv3x3(in_planes, out_planes, stride=1, groups=1, dilation=1):
-    return nn.Conv(in_planes, out_planes, kernel_size=3, stride=stride, padding=dilation, groups=groups, bias=False, dilation=dilation)
+    conv=nn.Conv(in_planes, out_planes, kernel_size=3, stride=stride, padding=dilation, groups=groups, bias=False, dilation=dilation)
+    jt.init.relu_invariant_gauss_(conv.weight, mode="fan_out")
+    return conv
 
 def conv1x1(in_planes, out_planes, stride=1):
-    return nn.Conv(in_planes, out_planes, kernel_size=1, stride=stride, bias=False)
+    conv=nn.Conv(in_planes, out_planes, kernel_size=1, stride=stride, bias=False)
+    jt.init.relu_invariant_gauss_(conv.weight, mode="fan_out")
+    return conv
 
 class BasicBlock(nn.Module):
     expansion = 1
@@ -102,6 +106,7 @@ class ResNet(nn.Module):
         self.groups = groups
         self.base_width = width_per_group
         self.conv1 = nn.Conv(3, self.inplanes, kernel_size=7, stride=2, padding=3, bias=False)
+        jt.init.relu_invariant_gauss_(self.conv1.weight, mode="fan_out")
         self.bn1 = norm_layer(self.inplanes)
         self.relu = nn.Relu()
         self.maxpool = nn.Pool(kernel_size=3, stride=2, padding=1, op='maximum')
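Note: the two hunks above re-initialize every ResNet conv weight with jt.init.relu_invariant_gauss_. As a mental model (an assumption for illustration, not something stated in this commit), it behaves like He/Kaiming normal initialization: weights are drawn from N(0, std^2) with std = sqrt(2 / fan), where mode="fan_out" takes fan = out_channels * Kh * Kw. A minimal NumPy sketch of that idea:

import numpy as np

def relu_invariant_gauss_sketch(shape, mode="fan_out"):
    # Hypothetical stand-in for jt.init.relu_invariant_gauss_, for illustration.
    # For a conv weight [out_c, in_c, kh, kw]:
    #   fan_in = in_c * kh * kw,  fan_out = out_c * kh * kw
    receptive = int(np.prod(shape[2:])) if len(shape) > 2 else 1
    fan = (shape[0] if mode == "fan_out" else shape[1]) * receptive
    std = np.sqrt(2.0 / fan)  # gain sqrt(2) compensates for ReLU halving the variance
    return np.random.normal(0.0, std, size=shape).astype(np.float32)

w = relu_invariant_gauss_sketch([64, 3, 7, 7])  # same shape as conv1 above
print(w.std())  # close to sqrt(2 / (64*7*7)) ~= 0.0253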
@@ -338,9 +338,14 @@ class Conv(Module):
         assert in_channels % groups == 0, 'in_channels must be divisible by groups'
         assert out_channels % groups == 0, 'out_channels must be divisible by groups'
 
-        self.weight = init.relu_invariant_gauss([out_channels, in_channels//groups, Kh, Kw], dtype="float", mode="fan_out")
+        # self.weight = init.relu_invariant_gauss([out_channels, in_channels//groups, Kh, Kw], dtype="float", mode="fan_out")
+        self.weight = init.invariant_uniform([out_channels, in_channels//groups, Kh, Kw], dtype="float")
         if bias:
-            self.bias = init.uniform([out_channels], dtype="float", low=-1, high=1)
+            fan=1
+            for i in self.weight.shape[1:]:
+                fan *= i
+            bound = 1 / math.sqrt(fan)
+            self.bias = init.uniform([out_channels], dtype="float", low=-bound, high=bound)
         else:
             self.bias = None
 
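Note: in the Conv module itself, the default weight init switches to init.invariant_uniform, and the bias range changes from the fixed (-1, 1) to +/- 1/sqrt(fan_in), the convention PyTorch's Conv2d uses by default; the loop computes fan_in as the product of all weight dimensions except out_channels. A quick worked example (the shape is illustrative, not taken from the commit):

import math

weight_shape = [256, 64, 3, 3]      # [out_channels, in_channels//groups, Kh, Kw]
fan = math.prod(weight_shape[1:])   # 64 * 3 * 3 = 576, i.e. fan_in
bound = 1 / math.sqrt(fan)          # 1/24 ~= 0.0417
print(fan, round(bound, 4))         # 576 0.0417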
@@ -20,7 +20,7 @@ except Exception as e:
 def get_error(a, b):
     return np.abs(a-b) / max(np.abs(a), np.abs(b), 1e-5) , np.abs(a-b)
 
-def check(jt_mod, torch_mod, rtol=1e-2, atol=1e-5):
+def check(jt_mod, torch_mod, rtol=1e-2, atol=1e-5, mean_atol=1e-5):
     pa = [ p for p in jt_mod.parameters() if not p.is_stop_grad() ]
     pb = list(torch_mod.parameters())
     assert len(pa) == len(pb)
@@ -36,7 +36,7 @@ def check(jt_mod, torch_mod, rtol=1e-2, atol=1e-5):
         print("compare std error", stda, stdb, r_err, a_err, a.name(), a.shape)
 
         r_err, a_err = get_error(meana, meanb)
-        if r_err > rtol and a_err > atol:
+        if r_err > rtol and a_err > mean_atol:
             error_count += 1
             print("compare mean error", meana, meanb, r_err, a_err, a.name(), a.shape)
     assert error_count == 0
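Note: check only counts a parameter pair as a failure when the relative and the absolute error both exceed their thresholds, so passing either test is enough; the new mean_atol argument lets the mean comparison use a looser absolute floor than the std comparison. A small self-contained sketch of that predicate (the sample values are made up):

import numpy as np

def get_error(a, b):
    # (relative error, guarded against division by ~0; absolute error)
    return np.abs(a-b) / max(np.abs(a), np.abs(b), 1e-5), np.abs(a-b)

r_err, a_err = get_error(0.0251, 0.0248)
# Flagged only if BOTH tolerances are exceeded:
print(r_err > 1e-2 and a_err > 1e-3)  # False -> this pair passes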
@@ -50,10 +50,10 @@ class TestInit(unittest.TestCase):
         torch.manual_seed(0)
 
     def test_conv(self):
-        check(jt.nn.Conv(64, 256, 3), torch.nn.Conv2d(64, 256, 3))
+        check(jt.nn.Conv(64, 256, 3), torch.nn.Conv2d(64, 256, 3), rtol=1e-1, mean_atol=1e-3)
 
     def test_resnet(self):
-        check(models.resnet152(), torchvision.models.resnet152(), rtol=2e-2)
+        check(models.resnet152(), torchvision.models.resnet152(), rtol=2e-2, mean_atol=1e-2)
 
 if __name__ == "__main__":
     unittest.main()