fix zero dim broadcast

This commit is contained in:
Dun Liang 2020-08-17 20:20:31 +08:00
parent f75c26ab55
commit 7e1678c7ea
4 changed files with 11 additions and 6 deletions

View File

@ -7,7 +7,7 @@
# This file is subject to the terms and conditions defined in
# file 'LICENSE.txt', which is part of this source code package.
# ***************************************************************
__version__ = '1.1.7.9'
__version__ = '1.1.7.10'
from . import lock
with lock.lock_scope():
from . import compiler

View File

@ -133,6 +133,11 @@ class TestBroadcastToOpMisc(unittest.TestCase):
assert (a.broadcast(b, [-1]).data == [[1,1],[2,2]]).all()
assert (a.broadcast(b, [-2]).data == [[1,2],[1,2]]).all()
def test_zero_dim(self):
a = jt.array(1.0)
b = a.broadcast([0])
assert b.shape == [0]
if __name__ == "__main__":
unittest.main()

View File

@ -63,7 +63,7 @@ BroadcastToOp::BroadcastToOp(Var* x, NanoVector shape, NanoVector dims) : x(x),
set_type(OpType::broadcast);
CHECKop(shape.size(),>,0u) << "Number of shape should greater than 0.";
for (auto v : shape)
CHECKop(v,>,0u) << "Shape should greater than 0.";
CHECKop(v,>=,0u) << "Shape should greater than 0.";
z = create_output(nullptr, x->dtype());
bcast_mask = 0;
keepdims_mask = 0;
@ -78,7 +78,7 @@ BroadcastToOp::BroadcastToOp(Var* x, NanoVector shape, NanoVector dims) : x(x),
bool BroadcastToOp::need_broadcast(const Var* x, const NanoVector& shape) {
if (x->shape.size() < shape.size()) return true;
for (uint i=shape.size()-1, j=x->shape.size()-1; i<shape.size(); i--,j--)
if (x->shape[j]< 0 || x->shape[j] < shape[i]) return true;
if (x->shape[j]< 0 || (x->shape[j] != shape[i] && shape[i] != 1)) return true;
return false;
}

View File

@ -57,10 +57,10 @@ void ReindexReduceOp::infer_shape() {
CHECKop(shape.size(),==,indexes.size()) << "Number of shape and indexes should be the same.";
CHECK(shape.size()) << "Number of shape should greater than 0.";
for (auto v : shape)
CHECKop(v,>,0u) << "Shape should greater than 0.";
CHECKop(v,>=,0u) << "Shape should greater than 0.";
x->set_shape(shape);
CHECKop(x->size,>,0u);
CHECKop(y->size,>,0u);
CHECKop(x->size,>=,0u);
CHECKop(y->size,>=,0u);
}
void ReindexReduceOp::jit_prepare() {