fix: support reshape empty var with uncertain dimension

This commit is contained in:
lzhengning 2022-09-16 15:36:33 +08:00
parent e661f19e20
commit 05ed6c7e34
2 changed files with 13 additions and 2 deletions

View File

@ -44,8 +44,12 @@ void ReshapeOp::infer_shape() {
if (uncertain_dim == 0) { if (uncertain_dim == 0) {
CHECKop(x_items,==,y_items) << "reshape shape is invalid for input of size"; CHECKop(x_items,==,y_items) << "reshape shape is invalid for input of size";
} else { } else {
CHECK(y_items != 0 && x_items % y_items == 0) << "reshape shape is invalid for input of size " << x_items; if (x_items == 0) {
uncertain_dim = x_items / y_items; uncertain_dim = 0;
} else {
CHECK(y_items != 0 && x_items % y_items == 0) << "reshape shape is invalid for input of size " << x_items;
uncertain_dim = x_items / y_items;
}
yshape.clear(); yshape.clear();
for (auto a : shape) for (auto a : shape)
yshape.push_back(a<0 ? uncertain_dim : a); yshape.push_back(a<0 ? uncertain_dim : a);

View File

@ -75,6 +75,13 @@ class TestReshapeOp(unittest.TestCase):
a = jt.zeros(10) a = jt.zeros(10)
b = a.reshape(a.shape) b = a.reshape(a.shape)
def test_reshape_empty(self):
a = jt.array([])
b = a.reshape(0, 1, 2)
assert b.shape == [0, 1, 2]
b = a.reshape(0, -1)
assert b.shape == [0, 0]
if __name__ == "__main__": if __name__ == "__main__":
unittest.main() unittest.main()