fix densenet vary shape input

This commit is contained in:
Dun Liang 2021-02-25 11:39:55 +08:00
parent 3df3bfc35c
commit 73aa382627
4 changed files with 4 additions and 3 deletions

View File

@ -8,7 +8,7 @@
# This file is subject to the terms and conditions defined in
# file 'LICENSE.txt', which is part of this source code package.
# ***************************************************************
__version__ = '1.2.2.33'
__version__ = '1.2.2.34'
from . import lock
with lock.lock_scope():
ori_int = int

View File

@ -140,6 +140,6 @@ class DenseNet(nn.Module):
def execute(self, x):
features = self.features(x)
out = nn.relu(features)
out = jt.pool.pool(out, kernel_size=7, op="mean", stride=1).reshape([features.shape[0], -1])
out = out.mean([2,3])
out = self.classifier(out)
return out

View File

@ -27,7 +27,7 @@ def matmul_transpose(a, b):
returns a * b^T
'''
assert len(a.shape) >= 2 and len(b.shape) == 2
assert a.shape[-1] == b.shape[-1]
assert a.shape[-1] == b.shape[-1], (a.shape, b.shape)
if len(a.shape)>2:
aa = a.reshape((-1, a.shape[-1]))
cc = matmul_transpose(aa, b)

View File

@ -78,6 +78,7 @@ class test_models(unittest.TestCase):
# Define numpy input image
bs = 1
test_img = np.random.random((bs,3,224,224)).astype('float32')
# test_img = np.random.random((bs,3,280,280)).astype('float32')
# Define pytorch & jittor input image
pytorch_test_img = to_cuda(torch.Tensor(test_img))
jittor_test_img = jt.array(test_img)