fix init bug

guowei yang 2020-05-01 23:55:14 +08:00
parent 78cc826756
commit d98958e627
3 changed files with 18 additions and 8 deletions


@@ -15,10 +15,14 @@ __all__ = ['ResNet', 'Resnet18', 'Resnet34', 'Resnet50', 'Resnet101', 'Resnet152
 'resnet18', 'resnet34', 'resnet50', 'resnet101', 'resnet152', 'resnext50_32x4d', 'resnext101_32x8d', 'wide_resnet50_2', 'wide_resnet101_2']
 
 def conv3x3(in_planes, out_planes, stride=1, groups=1, dilation=1):
-    return nn.Conv(in_planes, out_planes, kernel_size=3, stride=stride, padding=dilation, groups=groups, bias=False, dilation=dilation)
+    conv=nn.Conv(in_planes, out_planes, kernel_size=3, stride=stride, padding=dilation, groups=groups, bias=False, dilation=dilation)
+    jt.init.relu_invariant_gauss_(conv.weight, mode="fan_out")
+    return conv
 
 def conv1x1(in_planes, out_planes, stride=1):
-    return nn.Conv(in_planes, out_planes, kernel_size=1, stride=stride, bias=False)
+    conv=nn.Conv(in_planes, out_planes, kernel_size=1, stride=stride, bias=False)
+    jt.init.relu_invariant_gauss_(conv.weight, mode="fan_out")
+    return conv
 
 class BasicBlock(nn.Module):
     expansion = 1
@@ -102,6 +106,7 @@ class ResNet(nn.Module):
         self.groups = groups
         self.base_width = width_per_group
         self.conv1 = nn.Conv(3, self.inplanes, kernel_size=7, stride=2, padding=3, bias=False)
+        jt.init.relu_invariant_gauss_(self.conv1.weight, mode="fan_out")
         self.bn1 = norm_layer(self.inplanes)
         self.relu = nn.Relu()
         self.maxpool = nn.Pool(kernel_size=3, stride=2, padding=1, op='maximum')
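
A minimal sketch of the re-initialization pattern introduced above, using only the jittor calls that appear in this diff (nn.Conv, jt.init.relu_invariant_gauss_); the layer sizes are made up:

import numpy as np
import jittor as jt
from jittor import nn

# Build a conv the way conv3x3 does, then overwrite its default weights with the
# fan_out ReLU-invariant Gaussian init that the new code applies.
conv = nn.Conv(64, 128, kernel_size=3, stride=1, padding=1, bias=False)
print(np.std(conv.weight.numpy()))   # std under the default init
jt.init.relu_invariant_gauss_(conv.weight, mode="fan_out")
print(np.std(conv.weight.numpy()))   # should land roughly at sqrt(2/(128*3*3)) ≈ 0.042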


@@ -338,9 +338,14 @@ class Conv(Module):
         assert in_channels % groups == 0, 'in_channels must be divisible by groups'
         assert out_channels % groups == 0, 'out_channels must be divisible by groups'
-        self.weight = init.relu_invariant_gauss([out_channels, in_channels//groups, Kh, Kw], dtype="float", mode="fan_out")
+        # self.weight = init.relu_invariant_gauss([out_channels, in_channels//groups, Kh, Kw], dtype="float", mode="fan_out")
+        self.weight = init.invariant_uniform([out_channels, in_channels//groups, Kh, Kw], dtype="float")
         if bias:
-            self.bias = init.uniform([out_channels], dtype="float", low=-1, high=1)
+            fan=1
+            for i in self.weight.shape[1:]:
+                fan *= i
+            bound = 1 / math.sqrt(fan)
+            self.bias = init.uniform([out_channels], dtype="float", low=-bound, high=bound)
         else:
             self.bias = None
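
The bound computed above matches PyTorch's default Conv2d bias initialization (uniform in ±1/sqrt(fan_in)). A small worked example with made-up sizes, following the same loop as the new code:

import math

# weight shape is [out_channels, in_channels//groups, Kh, Kw]; fan multiplies
# every dimension after out_channels, i.e. the inputs feeding one output unit.
out_channels, in_channels, groups, Kh, Kw = 256, 64, 1, 3, 3
fan = (in_channels // groups) * Kh * Kw   # 64 * 3 * 3 = 576
bound = 1 / math.sqrt(fan)                # 1 / 24 ≈ 0.0417
print(fan, bound)
# bias ~ Uniform(-bound, bound) instead of the old Uniform(-1, 1), which was
# roughly 24x too wide for a layer of this size.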


@@ -20,7 +20,7 @@ except Exception as e:
 def get_error(a, b):
     return np.abs(a-b) / max(np.abs(a), np.abs(b), 1e-5) , np.abs(a-b)
 
-def check(jt_mod, torch_mod, rtol=1e-2, atol=1e-5):
+def check(jt_mod, torch_mod, rtol=1e-2, atol=1e-5, mean_atol=1e-5):
     pa = [ p for p in jt_mod.parameters() if not p.is_stop_grad() ]
     pb = list(torch_mod.parameters())
     assert len(pa) == len(pb)
@@ -36,7 +36,7 @@ def check(jt_mod, torch_mod, rtol=1e-2, atol=1e-5):
             print("compare std error", stda, stdb, r_err, a_err, a.name(), a.shape)
         r_err, a_err = get_error(meana, meanb)
-        if r_err > rtol and a_err > atol:
+        if r_err > rtol and a_err > mean_atol:
             error_count += 1
             print("compare mean error", meana, meanb, r_err, a_err, a.name(), a.shape)
     assert error_count == 0
@@ -50,10 +50,10 @@ class TestInit(unittest.TestCase):
         torch.manual_seed(0)
 
     def test_conv(self):
-        check(jt.nn.Conv(64, 256, 3), torch.nn.Conv2d(64, 256, 3))
+        check(jt.nn.Conv(64, 256, 3), torch.nn.Conv2d(64, 256, 3), rtol=1e-1, mean_atol=1e-3)
 
     def test_resnet(self):
-        check(models.resnet152(), torchvision.models.resnet152(), rtol=2e-2)
+        check(models.resnet152(), torchvision.models.resnet152(), rtol=2e-2, mean_atol=1e-2)
 
 if __name__ == "__main__":
     unittest.main()
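
A short sketch of why the mean comparison gets its own, looser mean_atol: the weight mean of a well-initialized conv sits near zero, so the relative error between two near-zero means is mostly noise. get_error is copied from the test above; the numbers are made up:

import numpy as np

def get_error(a, b):
    return np.abs(a-b) / max(np.abs(a), np.abs(b), 1e-5), np.abs(a-b)

meana, meanb = 3e-4, -2e-4          # hypothetical jittor / torch weight means
r_err, a_err = get_error(meana, meanb)
print(r_err, a_err)                  # r_err ≈ 1.67, a_err = 5e-4
# Under the old check (atol=1e-5) both thresholds are exceeded, so this pair is
# counted as an error even though both means are effectively zero. With
# mean_atol=1e-3 the absolute test passes and no error is counted, while the
# std comparison keeps the tight atol=1e-5.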