mirror of https://github.com/Jittor/Jittor
fix densenet vary shape input
This commit is contained in:
parent
3df3bfc35c
commit
73aa382627
|
@ -8,7 +8,7 @@
|
|||
# This file is subject to the terms and conditions defined in
|
||||
# file 'LICENSE.txt', which is part of this source code package.
|
||||
# ***************************************************************
|
||||
__version__ = '1.2.2.33'
|
||||
__version__ = '1.2.2.34'
|
||||
from . import lock
|
||||
with lock.lock_scope():
|
||||
ori_int = int
|
||||
|
|
|
@ -140,6 +140,6 @@ class DenseNet(nn.Module):
|
|||
def execute(self, x):
|
||||
features = self.features(x)
|
||||
out = nn.relu(features)
|
||||
out = jt.pool.pool(out, kernel_size=7, op="mean", stride=1).reshape([features.shape[0], -1])
|
||||
out = out.mean([2,3])
|
||||
out = self.classifier(out)
|
||||
return out
|
||||
|
|
|
@ -27,7 +27,7 @@ def matmul_transpose(a, b):
|
|||
returns a * b^T
|
||||
'''
|
||||
assert len(a.shape) >= 2 and len(b.shape) == 2
|
||||
assert a.shape[-1] == b.shape[-1]
|
||||
assert a.shape[-1] == b.shape[-1], (a.shape, b.shape)
|
||||
if len(a.shape)>2:
|
||||
aa = a.reshape((-1, a.shape[-1]))
|
||||
cc = matmul_transpose(aa, b)
|
||||
|
|
|
@ -78,6 +78,7 @@ class test_models(unittest.TestCase):
|
|||
# Define numpy input image
|
||||
bs = 1
|
||||
test_img = np.random.random((bs,3,224,224)).astype('float32')
|
||||
# test_img = np.random.random((bs,3,280,280)).astype('float32')
|
||||
# Define pytorch & jittor input image
|
||||
pytorch_test_img = to_cuda(torch.Tensor(test_img))
|
||||
jittor_test_img = jt.array(test_img)
|
||||
|
|
Loading…
Reference in New Issue