mirror of https://github.com/Jittor/Jittor
fix densenet bug
This commit is contained in:
parent
0068580ecb
commit
6b471e5c8e
|
@ -516,8 +516,11 @@ class Module:
|
|||
end = 0
|
||||
for k in key_:
|
||||
if isinstance(v, nn.Sequential):
|
||||
v = v[k]
|
||||
if v is None:
|
||||
if (k in v.layers):
|
||||
v = v[k]
|
||||
elif k.isdigit() and (ori_int(k) in v.layers):
|
||||
v = v[ori_int(k)]
|
||||
else:
|
||||
end=1
|
||||
break
|
||||
else:
|
||||
|
|
|
@ -140,5 +140,5 @@ class DenseNet(nn.Module):
|
|||
features = self.features(x)
|
||||
out = nn.relu(features)
|
||||
out = jt.pool.pool(out, kernel_size=7, op="mean", stride=1).reshape([features.shape[0], -1])
|
||||
out = jt.sigmoid(self.classifier(out))
|
||||
out = self.classifier(out)
|
||||
return out
|
||||
|
|
|
@ -26,7 +26,6 @@ class test_models(unittest.TestCase):
|
|||
@classmethod
|
||||
def setUpClass(self):
|
||||
self.models = [
|
||||
'inception_v3',
|
||||
'squeezenet1_0',
|
||||
'squeezenet1_1',
|
||||
'alexnet',
|
||||
|
@ -59,8 +58,9 @@ class test_models(unittest.TestCase):
|
|||
'shufflenet_v2_x2_0',
|
||||
"densenet121",
|
||||
"densenet161",
|
||||
"densenet169",
|
||||
'inception_v3',
|
||||
]
|
||||
self.models = ["densenet169"]
|
||||
|
||||
@unittest.skipIf(not jt.has_cuda, "Cuda not found")
|
||||
@jt.flag_scope(use_cuda=1)
|
||||
|
@ -81,7 +81,6 @@ class test_models(unittest.TestCase):
|
|||
pytorch_test_img = to_cuda(torch.Tensor(test_img))
|
||||
jittor_test_img = jt.array(test_img)
|
||||
for test_model in self.models:
|
||||
print("test model", test_model)
|
||||
if test_model == "inception_v3":
|
||||
test_img = np.random.random((bs,3,300,300)).astype('float32')
|
||||
pytorch_test_img = to_cuda(torch.Tensor(test_img))
|
||||
|
@ -101,10 +100,6 @@ class test_models(unittest.TestCase):
|
|||
y = jittor_result.data + 1
|
||||
relative_error = abs(x - y) / abs(y)
|
||||
diff = relative_error.mean()
|
||||
print(pytorch_result.shape, jittor_result.shape)
|
||||
print(pytorch_result)
|
||||
print(jittor_result)
|
||||
print(pytorch_result.detach().cpu().numpy() - jittor_result.data)
|
||||
assert diff < threshold, f"[*] {test_model} forward fails..., Relative Error: {diff}"
|
||||
print(f"[*] {test_model} forword passes with Relative Error {diff}")
|
||||
jt.clean()
|
||||
|
|
Loading…
Reference in New Issue