mirror of https://github.com/Jittor/Jittor
fix zero dim broadcast
This commit is contained in:
parent
f75c26ab55
commit
7e1678c7ea
|
@ -7,7 +7,7 @@
|
|||
# This file is subject to the terms and conditions defined in
|
||||
# file 'LICENSE.txt', which is part of this source code package.
|
||||
# ***************************************************************
|
||||
__version__ = '1.1.7.9'
|
||||
__version__ = '1.1.7.10'
|
||||
from . import lock
|
||||
with lock.lock_scope():
|
||||
from . import compiler
|
||||
|
|
|
@ -133,6 +133,11 @@ class TestBroadcastToOpMisc(unittest.TestCase):
|
|||
assert (a.broadcast(b, [-1]).data == [[1,1],[2,2]]).all()
|
||||
assert (a.broadcast(b, [-2]).data == [[1,2],[1,2]]).all()
|
||||
|
||||
def test_zero_dim(self):
|
||||
a = jt.array(1.0)
|
||||
b = a.broadcast([0])
|
||||
assert b.shape == [0]
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
unittest.main()
|
|
@ -63,7 +63,7 @@ BroadcastToOp::BroadcastToOp(Var* x, NanoVector shape, NanoVector dims) : x(x),
|
|||
set_type(OpType::broadcast);
|
||||
CHECKop(shape.size(),>,0u) << "Number of shape should greater than 0.";
|
||||
for (auto v : shape)
|
||||
CHECKop(v,>,0u) << "Shape should greater than 0.";
|
||||
CHECKop(v,>=,0u) << "Shape should greater than 0.";
|
||||
z = create_output(nullptr, x->dtype());
|
||||
bcast_mask = 0;
|
||||
keepdims_mask = 0;
|
||||
|
@ -78,7 +78,7 @@ BroadcastToOp::BroadcastToOp(Var* x, NanoVector shape, NanoVector dims) : x(x),
|
|||
bool BroadcastToOp::need_broadcast(const Var* x, const NanoVector& shape) {
|
||||
if (x->shape.size() < shape.size()) return true;
|
||||
for (uint i=shape.size()-1, j=x->shape.size()-1; i<shape.size(); i--,j--)
|
||||
if (x->shape[j]< 0 || x->shape[j] < shape[i]) return true;
|
||||
if (x->shape[j]< 0 || (x->shape[j] != shape[i] && shape[i] != 1)) return true;
|
||||
return false;
|
||||
}
|
||||
|
||||
|
|
|
@ -57,10 +57,10 @@ void ReindexReduceOp::infer_shape() {
|
|||
CHECKop(shape.size(),==,indexes.size()) << "Number of shape and indexes should be the same.";
|
||||
CHECK(shape.size()) << "Number of shape should greater than 0.";
|
||||
for (auto v : shape)
|
||||
CHECKop(v,>,0u) << "Shape should greater than 0.";
|
||||
CHECKop(v,>=,0u) << "Shape should greater than 0.";
|
||||
x->set_shape(shape);
|
||||
CHECKop(x->size,>,0u);
|
||||
CHECKop(y->size,>,0u);
|
||||
CHECKop(x->size,>=,0u);
|
||||
CHECKop(y->size,>=,0u);
|
||||
}
|
||||
|
||||
void ReindexReduceOp::jit_prepare() {
|
||||
|
|
Loading…
Reference in New Issue