From 05ed6c7e34b25199e6d373d95555e5bcba3bc95c Mon Sep 17 00:00:00 2001 From: lzhengning Date: Fri, 16 Sep 2022 15:36:33 +0800 Subject: [PATCH] fix: support reshape empty var with uncertain dimension --- python/jittor/src/ops/reshape_op.cc | 8 ++++++-- python/jittor/test/test_reshape.py | 7 +++++++ 2 files changed, 13 insertions(+), 2 deletions(-) diff --git a/python/jittor/src/ops/reshape_op.cc b/python/jittor/src/ops/reshape_op.cc index a02cae34..50b3f97a 100644 --- a/python/jittor/src/ops/reshape_op.cc +++ b/python/jittor/src/ops/reshape_op.cc @@ -44,8 +44,12 @@ void ReshapeOp::infer_shape() { if (uncertain_dim == 0) { CHECKop(x_items,==,y_items) << "reshape shape is invalid for input of size"; } else { - CHECK(y_items != 0 && x_items % y_items == 0) << "reshape shape is invalid for input of size " << x_items; - uncertain_dim = x_items / y_items; + if (x_items == 0) { + uncertain_dim = 0; + } else { + CHECK(y_items != 0 && x_items % y_items == 0) << "reshape shape is invalid for input of size " << x_items; + uncertain_dim = x_items / y_items; + } yshape.clear(); for (auto a : shape) yshape.push_back(a<0 ? uncertain_dim : a); diff --git a/python/jittor/test/test_reshape.py b/python/jittor/test/test_reshape.py index 87f20bb5..2269592c 100644 --- a/python/jittor/test/test_reshape.py +++ b/python/jittor/test/test_reshape.py @@ -75,6 +75,13 @@ class TestReshapeOp(unittest.TestCase): a = jt.zeros(10) b = a.reshape(a.shape) + def test_reshape_empty(self): + a = jt.array([]) + b = a.reshape(0, 1, 2) + assert b.shape == [0, 1, 2] + b = a.reshape(0, -1) + assert b.shape == [0, 0] + if __name__ == "__main__": unittest.main() \ No newline at end of file